mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
fixes
This commit is contained in:
@@ -546,7 +546,7 @@ def split_store(x:UOp) -> UOp|None:
|
||||
# if we have any open ranges here, we don't split
|
||||
if x.ranges: return None
|
||||
# raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently
|
||||
if x.op is Ops.STORE and x.src[0]._shape is not None: return None
|
||||
#if x.op is Ops.STORE and x.src[0]._shape is not None: return None
|
||||
|
||||
# local kernel rewrite
|
||||
lctx = LocalAddBufferContext()
|
||||
|
||||
@@ -100,6 +100,8 @@ class UOpMetaClass(type):
|
||||
buffers[created] = _buffer
|
||||
if SPEC > 1:
|
||||
from tinygrad.uop.spec import full_spec, test_pyrender
|
||||
# NOTE: while indexing, the STORE is wrong when only one side is updated
|
||||
#shp = created._shape
|
||||
if SPEC > 2: test_pyrender(created)
|
||||
with Context(CHECK_OOB=0): fret = cast(bool|None, full_spec.rewrite(created))
|
||||
if fret is not True: raise RuntimeError(f"SPEC ISSUE {fret}: {created}")
|
||||
@@ -237,7 +239,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.GEP: return (len(self.arg),) if len(self.arg) > 1 else ()
|
||||
case Ops.STACK: return (len(self.src),)
|
||||
case Ops.INDEX:
|
||||
shp = []
|
||||
shp:list[sint] = []
|
||||
# NOTE: the acc buffer can have a dtype with count, we need it back here
|
||||
if self.src[0].dtype.count > 1: shp.append(self.src[0].dtype.count)
|
||||
for s in self.src[1:]: shp.extend(list(s.shape))
|
||||
return tuple(shp) + self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user