mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
junk
This commit is contained in:
@@ -209,7 +209,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
|
||||
@recursive_property
|
||||
def _shape(self) -> tuple[sint, ...]|None:
|
||||
if self.dtype.count > 1 and self.op not in GroupOp.Movement and self.op is not Ops.GEP: return (self.dtype.count,)
|
||||
if self.dtype.count > 1: return (self.dtype.count,)
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
@@ -319,10 +319,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
|
||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||
if len(input_shapes) == 0: return None
|
||||
non_scalar_shapes = [s for s in input_shapes if s != ()]
|
||||
if len(non_scalar_shapes) == 0: return ()
|
||||
if not all_same(non_scalar_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
||||
return non_scalar_shapes[0]
|
||||
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
||||
return input_shapes[0]
|
||||
|
||||
# all Ops must be explicitly handled
|
||||
raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")
|
||||
|
||||
Reference in New Issue
Block a user