mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
param args aren't realized
This commit is contained in:
@@ -29,7 +29,7 @@ class ParamArg:
|
||||
device: str|tuple[str, ...]|None = None
|
||||
def __repr__(self):
|
||||
fields = (("vmin_vmax", None), ("name", None), ("addrspace", AddrSpace.GLOBAL), ("axis", None), ("device", None))
|
||||
args = [str(self.slot)] + [f"{k}={v!r}" for k,default in fields if (v:=getattr(self, k)) != default]
|
||||
args = [repr(self.slot)] + [f"{k}={v!r}" for k,default in fields if (v:=getattr(self, k)) != default]
|
||||
return f"ParamArg({', '.join(args)})"
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
||||
@@ -150,7 +150,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
except AttributeError: pass
|
||||
def __reduce__(self):
|
||||
args = [self.op, self.dtype, self.src, self.arg, self.tag, self.metadata]
|
||||
if self.op is Ops.BUFFER and not isinstance(self.arg, ParamArg) and self.realized is not None: args.append(self.realized)
|
||||
if self.op is Ops.BUFFER and self.realized is not None: args.append(self.realized)
|
||||
return UOp, tuple(args)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src),
|
||||
@@ -690,7 +690,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.STACK: return self.src[i].sintify()
|
||||
case _: raise RuntimeError(f"no sgep on {self.op}")
|
||||
|
||||
@functools.cached_property
|
||||
# cached property here makes external_uop_gc fail, why?
|
||||
@property
|
||||
def as_shape(self) -> tuple[sint, ...]:
|
||||
if self.op is Ops.CONST: return (self.arg,)*self.dtype.count # NOTE: this will break
|
||||
if self.op is not Ops.STACK: return (ssimplify(self),)
|
||||
@@ -866,6 +867,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def realized(self) -> Buffer|MultiBuffer|None:
|
||||
# only these can be realized
|
||||
if self.op not in (Ops.BUFFER, Ops.MSTACK): return None
|
||||
if self.op is Ops.BUFFER and isinstance(self.arg, ParamArg): return None
|
||||
# LUNIQUEs are never realized
|
||||
if self.op_in_backward_slice_with_self(Ops.LUNIQUE): return None
|
||||
# NOTE: this is used by the JIT to determine which inputs we capture
|
||||
|
||||
Reference in New Issue
Block a user