diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 87b6729fb4..7de5014cbe 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -696,15 +696,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.STACK: return self.src[i].sintify() case _: raise RuntimeError(f"no sgep on {self.op}") + @functools.cached_property + def as_shape(self) -> tuple[sint, ...]: + return tuple(ssimplify(self.sgep(i)) for i in range(max(self.dtype.count, len(self.src)))) + @functools.cached_property def marg(self): match self.op: - case Ops.RESHAPE | Ops.EXPAND: return tuple(ssimplify(self.src[1].sgep(i)) for i in range(self.src[1].dtype.count)) - case Ops.PAD | Ops.SHRINK: - # this is like broadcasting for shapes - return tuple(((ssimplify(self.src[1]) if self.src[1].shape == () else self.src[1].sgep(i)), - (ssimplify(self.src[2]) if self.src[2].shape == () else self.src[2].sgep(i))) - for i in range(max(self.src[1].dtype.count, self.src[2].dtype.count))) + case Ops.RESHAPE | Ops.EXPAND: return self.src[1].as_shape + case Ops.PAD | Ops.SHRINK: return tuple(zip(self.src[1].as_shape, self.src[2].as_shape)) case Ops.PERMUTE | Ops.FLIP: return self.arg case _: raise RuntimeError(f"{self.op} is not a MovementOp")