replace DEFINE_VAR with PARAM (#16576)

* replace DEFINE_VAR with PARAM

* cleanups

* cleanups
This commit is contained in:
George Hotz
2026-06-11 15:03:20 -07:00
committed by GitHub
parent 5f1e2d3900
commit 587333fddb
6 changed files with 22 additions and 16 deletions

View File

@@ -44,6 +44,9 @@ pm_remove_vec_dtypes = PatternMatcher([
# replace DEFINE_LOCAL/DEFINE_REG with BUFFER
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
# replace DEFINE_VAR with PARAM
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x:
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=ctx[x.arg[0]], name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))),
])+pm_clean_up_group_sink
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
@@ -125,7 +128,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if ren.new_style:
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM])
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)

View File

@@ -50,7 +50,7 @@ class DTypeMetaClass(type):
return ret
class AddrSpace(IntEnum):
def __repr__(self): return str(self)
def __repr__(self): return f"{self.__class__.__name__}.{self.name}"
GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
@dataclass(frozen=True, eq=False)

View File

@@ -39,7 +39,7 @@ class Estimates:
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
elif (u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.arg.addrspace is None)) and u.expr == 'core_id': mults *= int(u.vmax) + 1
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.dtype.scalar().itemsize * mults
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:

View File

@@ -220,9 +220,8 @@ class CStyleLanguage(Renderer):
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
if u.op is not Ops.PARAM: r[u] = u.arg[0]
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
if u.op is Ops.PARAM:
if isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
else: r[u] = f"data{u.arg.slot}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg.slot}"
bufs[u] = (r[u], (u, u in writable_params))
continue

View File

@@ -148,8 +148,8 @@ class LLVMRenderer(Renderer):
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg.slot}" if u.op is Ops.PARAM else f"%{u.expr}"
if u.op is Ops.PARAM:
r[u] = f"%data{u.arg.slot}"
args.append((r[u], u))
elif u.op is Ops.BUFFER:
r[u] = f"%{'local' if u.addrspace == AddrSpace.LOCAL else 'reg'}_{str(u.arg.slot)}"

View File

@@ -24,7 +24,7 @@ class ParamArg:
slot: int
vmin_vmax: tuple[PyConst, PyConst]|None = None
name: str|None = None
addrspace: AddrSpace = AddrSpace.GLOBAL
addrspace: AddrSpace|None = AddrSpace.GLOBAL
axis: int|None = None
device: str|tuple[str, ...]|None = None
def __repr__(self):
@@ -298,7 +298,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
case Ops.PARAM:
if isinstance(self.dtype, ImageDType): return self.dtype.shape
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count)) if len(self.src) >= 1 else None
return self.src[0].as_shape if len(self.src) >= 1 else None
# wmma output shape = accumulator shape (src[2])
case Ops.WMMA | Ops.SHAPED_WMMA: return self.src[2]._shape
@@ -898,6 +898,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property
def expr(self) -> str:
if self.op is Ops.PARAM and self.arg.addrspace is None: return unwrap(self.arg.name)
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
def bind(self, val:int|UOp):
@@ -914,7 +915,8 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
@property
def val(self) -> int: return self.unbind()[1]
def variables(self) -> list[Variable]:
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR}, key=lambda v: v.arg)
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)},
key=lambda v: v.expr)
# *** uop symbolic stuff ***
@@ -1022,7 +1024,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def _sym_fxn(self):
from tinygrad.uop.render import _render_with_splits, renderer_infer
sself = self.simplify()
varnames = tuple(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)))
# TODO: sanitize varnames, or don't use naked eval while staying fast
ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself})
lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"]
@@ -1138,8 +1140,8 @@ class ProgramInfo:
global_size: list[int] = [1, 1, 1]
local_size: list[int]|None = [1, 1, 1]
for u in sink.toposort():
if u.op is Ops.DEFINE_VAR: _vars.append(u)
if u.op is Ops.PARAM: _globals.append(u.arg.slot)
if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace is None): _vars.append(u)
if u.op is Ops.PARAM and u.addrspace is not None: _globals.append(u.arg.slot)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
@@ -1147,9 +1149,9 @@ class ProgramInfo:
if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
if u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': global_size[0] = u.arg[2] + 1
if u.op in (Ops.DEFINE_VAR, Ops.PARAM) and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size),
tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.arg)),
tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.expr)),
tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)
@dataclass(frozen=True)