param args aren't realized

This commit is contained in:
George Hotz
2026-06-06 10:42:29 -07:00
parent 5b2cd04789
commit f011cff4e2

View File

@@ -29,7 +29,7 @@ class ParamArg:
device: str|tuple[str, ...]|None = None
def __repr__(self):
fields = (("vmin_vmax", None), ("name", None), ("addrspace", AddrSpace.GLOBAL), ("axis", None), ("device", None))
args = [str(self.slot)] + [f"{k}={v!r}" for k,default in fields if (v:=getattr(self, k)) != default]
args = [repr(self.slot)] + [f"{k}={v!r}" for k,default in fields if (v:=getattr(self, k)) != default]
return f"ParamArg({', '.join(args)})"
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
@@ -150,7 +150,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
except AttributeError: pass
def __reduce__(self):
args = [self.op, self.dtype, self.src, self.arg, self.tag, self.metadata]
if self.op is Ops.BUFFER and not isinstance(self.arg, ParamArg) and self.realized is not None: args.append(self.realized)
if self.op is Ops.BUFFER and self.realized is not None: args.append(self.realized)
return UOp, tuple(args)
def replace(self, **kwargs) -> UOp:
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src),
@@ -690,7 +690,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.STACK: return self.src[i].sintify()
case _: raise RuntimeError(f"no sgep on {self.op}")
@functools.cached_property
# cached property here makes external_uop_gc fail, why?
@property
def as_shape(self) -> tuple[sint, ...]:
if self.op is Ops.CONST: return (self.arg,)*self.dtype.count # NOTE: this will break
if self.op is not Ops.STACK: return (ssimplify(self),)
@@ -866,6 +867,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def realized(self) -> Buffer|MultiBuffer|None:
# only these can be realized
if self.op not in (Ops.BUFFER, Ops.MSTACK): return None
if self.op is Ops.BUFFER and isinstance(self.arg, ParamArg): return None
# LUNIQUEs are never realized
if self.op_in_backward_slice_with_self(Ops.LUNIQUE): return None
# NOTE: this is used by the JIT to determine which inputs we capture