This commit is contained in:
George Hotz
2026-04-29 21:55:30 +00:00
parent 0095b91ce6
commit f8a66639a6

View File

@@ -209,7 +209,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@recursive_property
def _shape(self) -> tuple[sint, ...]|None:
if self.dtype.count > 1 and self.op not in GroupOp.Movement and self.op is not Ops.GEP: return (self.dtype.count,)
if self.dtype.count > 1: return (self.dtype.count,)
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
@@ -319,10 +319,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
input_shapes = [x._shape for x in self.src if x._shape is not None]
if len(input_shapes) == 0: return None
non_scalar_shapes = [s for s in input_shapes if s != ()]
if len(non_scalar_shapes) == 0: return ()
if not all_same(non_scalar_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
return non_scalar_shapes[0]
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
return input_shapes[0]
# all Ops must be explicitly handled
raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")