From f8a66639a61fa3fd695cdbc823f8a8f8e9102929 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 29 Apr 2026 21:55:30 +0000 Subject: [PATCH] junk --- tinygrad/uop/ops.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 97c810c8fe..3eb4d97248 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -209,7 +209,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): @recursive_property def _shape(self) -> tuple[sint, ...]|None: - if self.dtype.count > 1 and self.op not in GroupOp.Movement and self.op is not Ops.GEP: return (self.dtype.count,) + if self.dtype.count > 1: return (self.dtype.count,) match self.op: # late ops don't have shape case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ @@ -319,10 +319,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}): input_shapes = [x._shape for x in self.src if x._shape is not None] if len(input_shapes) == 0: return None - non_scalar_shapes = [s for s in input_shapes if s != ()] - if len(non_scalar_shapes) == 0: return () - if not all_same(non_scalar_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}") - return non_scalar_shapes[0] + if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}") + return input_shapes[0] # all Ops must be explicitly handled raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")