This commit is contained in:
George Hotz
2026-04-29 18:20:20 -07:00
parent 4bf0c35300
commit cfdff84df0
2 changed files with 6 additions and 2 deletions

View File

@@ -546,7 +546,7 @@ def split_store(x:UOp) -> UOp|None:
# if we have any open ranges here, we don't split
if x.ranges: return None
# raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently
if x.op is Ops.STORE and x.src[0]._shape is not None: return None
#if x.op is Ops.STORE and x.src[0]._shape is not None: return None
# local kernel rewrite
lctx = LocalAddBufferContext()

View File

@@ -100,6 +100,8 @@ class UOpMetaClass(type):
buffers[created] = _buffer
if SPEC > 1:
from tinygrad.uop.spec import full_spec, test_pyrender
# NOTE: while indexing, the STORE is wrong when only one side is updated
#shp = created._shape
if SPEC > 2: test_pyrender(created)
with Context(CHECK_OOB=0): fret = cast(bool|None, full_spec.rewrite(created))
if fret is not True: raise RuntimeError(f"SPEC ISSUE {fret}: {created}")
@@ -237,7 +239,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.GEP: return (len(self.arg),) if len(self.arg) > 1 else ()
case Ops.STACK: return (len(self.src),)
case Ops.INDEX:
shp = []
shp:list[sint] = []
# NOTE: the acc buffer can have a dtype with count, we need it back here
if self.src[0].dtype.count > 1: shp.append(self.src[0].dtype.count)
for s in self.src[1:]: shp.extend(list(s.shape))
return tuple(shp) + self.src[0].shape[len(self.src[1:]):]