From 587333fddb87019e1f9df424a594d1e426b022f4 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:03:20 -0700 Subject: [PATCH] replace DEFINE_VAR with PARAM (#16576) * replace DEFINE_VAR with PARAM * cleanups * cleanups --- tinygrad/codegen/__init__.py | 7 ++++++- tinygrad/dtype.py | 2 +- tinygrad/renderer/__init__.py | 2 +- tinygrad/renderer/cstyle.py | 5 ++--- tinygrad/renderer/llvmir.py | 4 ++-- tinygrad/uop/ops.py | 18 ++++++++++-------- 6 files changed, 22 insertions(+), 16 deletions(-) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 62dbcf0d41..1d8f972128 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -44,6 +44,9 @@ pm_remove_vec_dtypes = PatternMatcher([ # replace DEFINE_LOCAL/DEFINE_REG with BUFFER (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x: x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))), + # replace DEFINE_VAR with PARAM + (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: + x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=ctx[x.arg[0]], name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))), ])+pm_clean_up_group_sink def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: @@ -125,7 +128,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: if ren.new_style: sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink") - sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style") + num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM]) + name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))} + sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style") # this was the linearizer sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True) diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 3c85b318ab..205e5c67c6 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -50,7 +50,7 @@ class DTypeMetaClass(type): return ret class AddrSpace(IntEnum): - def __repr__(self): return str(self) + def __repr__(self): return f"{self.__class__.__name__}.{self.name}" GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702 @dataclass(frozen=True, eq=False) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 9ff02b86db..5fdcd47974 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -39,7 +39,7 @@ class Estimates: mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults elif u.op is Ops.END: mults = mult_stack.pop(-1) elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these - elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1 + elif (u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.arg.addrspace is None)) and u.expr == 'core_id': mults *= int(u.vmax) + 1 elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG: lds += u.max_numel() * u.dtype.scalar().itemsize * mults elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 67ade32869..3b64361eb6 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -220,9 +220,8 @@ class CStyleLanguage(Renderer): if u.op is Ops.SINK: if u.arg is not None: name = u.arg.function_name continue - if u.op in (Ops.PARAM, Ops.DEFINE_VAR): - if u.op is not Ops.PARAM: r[u] = u.arg[0] - elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}" + if u.op is Ops.PARAM: + if isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}" else: r[u] = f"data{u.arg.slot}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg.slot}" bufs[u] = (r[u], (u, u in writable_params)) continue diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index f12dccd54a..c404861695 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -148,8 +148,8 @@ class LLVMRenderer(Renderer): if u.op is Ops.SINK: if u.arg is not None: name = u.arg.function_name continue - if u.op in (Ops.PARAM, Ops.DEFINE_VAR): - r[u] = f"%data{u.arg.slot}" if u.op is Ops.PARAM else f"%{u.expr}" + if u.op is Ops.PARAM: + r[u] = f"%data{u.arg.slot}" args.append((r[u], u)) elif u.op is Ops.BUFFER: r[u] = f"%{'local' if u.addrspace == AddrSpace.LOCAL else 'reg'}_{str(u.arg.slot)}" diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index f113604252..8d65b1a0f5 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -24,7 +24,7 @@ class ParamArg: slot: int vmin_vmax: tuple[PyConst, PyConst]|None = None name: str|None = None - addrspace: AddrSpace = AddrSpace.GLOBAL + addrspace: AddrSpace|None = AddrSpace.GLOBAL axis: int|None = None device: str|tuple[str, ...]|None = None def __repr__(self): @@ -298,7 +298,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass): case Ops.PARAM: if isinstance(self.dtype, ImageDType): return self.dtype.shape if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,) - return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count)) if len(self.src) >= 1 else None + return self.src[0].as_shape if len(self.src) >= 1 else None # wmma output shape = accumulator shape (src[2]) case Ops.WMMA | Ops.SHAPED_WMMA: return self.src[2]._shape @@ -898,6 +898,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass): return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @property def expr(self) -> str: + if self.op is Ops.PARAM and self.arg.addrspace is None: return unwrap(self.arg.name) assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" return self.arg[0] def bind(self, val:int|UOp): @@ -914,7 +915,8 @@ class UOp(RandMixin, metaclass=UOpMetaClass): @property def val(self) -> int: return self.unbind()[1] def variables(self) -> list[Variable]: - return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR}, key=lambda v: v.arg) + return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)}, + key=lambda v: v.expr) # *** uop symbolic stuff *** @@ -1022,7 +1024,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass): def _sym_fxn(self): from tinygrad.uop.render import _render_with_splits, renderer_infer sself = self.simplify() - varnames = tuple(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR) + varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None))) # TODO: sanitize varnames, or don't use naked eval while staying fast ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself}) lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"] @@ -1138,8 +1140,8 @@ class ProgramInfo: global_size: list[int] = [1, 1, 1] local_size: list[int]|None = [1, 1, 1] for u in sink.toposort(): - if u.op is Ops.DEFINE_VAR: _vars.append(u) - if u.op is Ops.PARAM: _globals.append(u.arg.slot) + if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace is None): _vars.append(u) + if u.op is Ops.PARAM and u.addrspace is not None: _globals.append(u.arg.slot) if u.op in (Ops.STORE, Ops.LOAD): if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX): if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot) @@ -1147,9 +1149,9 @@ class ProgramInfo: if u.arg[0] == 'i': local_size = None special_size = local_size if u.arg[0] == 'l' else global_size if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify()) - if u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': global_size[0] = u.arg[2] + 1 + if u.op in (Ops.DEFINE_VAR, Ops.PARAM) and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1 return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size), - tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.arg)), + tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.expr)), tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux) @dataclass(frozen=True)