mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 01:15:49 +08:00
replace DEFINE_VAR with PARAM (#16576)
* replace DEFINE_VAR with PARAM * cleanups * cleanups
This commit is contained in:
@@ -44,6 +44,9 @@ pm_remove_vec_dtypes = PatternMatcher([
|
||||
# replace DEFINE_LOCAL/DEFINE_REG with BUFFER
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
|
||||
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
|
||||
# replace DEFINE_VAR with PARAM
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x:
|
||||
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=ctx[x.arg[0]], name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
@@ -125,7 +128,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
|
||||
if ren.new_style:
|
||||
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
|
||||
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM])
|
||||
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
|
||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
|
||||
|
||||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
@@ -50,7 +50,7 @@ class DTypeMetaClass(type):
|
||||
return ret
|
||||
|
||||
class AddrSpace(IntEnum):
|
||||
def __repr__(self): return str(self)
|
||||
def __repr__(self): return f"{self.__class__.__name__}.{self.name}"
|
||||
GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
|
||||
@@ -39,7 +39,7 @@ class Estimates:
|
||||
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
|
||||
elif u.op is Ops.END: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
|
||||
elif (u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.arg.addrspace is None)) and u.expr == 'core_id': mults *= int(u.vmax) + 1
|
||||
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.dtype.scalar().itemsize * mults
|
||||
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
|
||||
|
||||
@@ -220,9 +220,8 @@ class CStyleLanguage(Renderer):
|
||||
if u.op is Ops.SINK:
|
||||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
|
||||
if u.op is not Ops.PARAM: r[u] = u.arg[0]
|
||||
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
|
||||
if u.op is Ops.PARAM:
|
||||
if isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
|
||||
else: r[u] = f"data{u.arg.slot}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg.slot}"
|
||||
bufs[u] = (r[u], (u, u in writable_params))
|
||||
continue
|
||||
|
||||
@@ -148,8 +148,8 @@ class LLVMRenderer(Renderer):
|
||||
if u.op is Ops.SINK:
|
||||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
|
||||
r[u] = f"%data{u.arg.slot}" if u.op is Ops.PARAM else f"%{u.expr}"
|
||||
if u.op is Ops.PARAM:
|
||||
r[u] = f"%data{u.arg.slot}"
|
||||
args.append((r[u], u))
|
||||
elif u.op is Ops.BUFFER:
|
||||
r[u] = f"%{'local' if u.addrspace == AddrSpace.LOCAL else 'reg'}_{str(u.arg.slot)}"
|
||||
|
||||
@@ -24,7 +24,7 @@ class ParamArg:
|
||||
slot: int
|
||||
vmin_vmax: tuple[PyConst, PyConst]|None = None
|
||||
name: str|None = None
|
||||
addrspace: AddrSpace = AddrSpace.GLOBAL
|
||||
addrspace: AddrSpace|None = AddrSpace.GLOBAL
|
||||
axis: int|None = None
|
||||
device: str|tuple[str, ...]|None = None
|
||||
def __repr__(self):
|
||||
@@ -298,7 +298,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||
case Ops.PARAM:
|
||||
if isinstance(self.dtype, ImageDType): return self.dtype.shape
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
|
||||
return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count)) if len(self.src) >= 1 else None
|
||||
return self.src[0].as_shape if len(self.src) >= 1 else None
|
||||
|
||||
# wmma output shape = accumulator shape (src[2])
|
||||
case Ops.WMMA | Ops.SHAPED_WMMA: return self.src[2]._shape
|
||||
@@ -898,6 +898,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
@property
|
||||
def expr(self) -> str:
|
||||
if self.op is Ops.PARAM and self.arg.addrspace is None: return unwrap(self.arg.name)
|
||||
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
return self.arg[0]
|
||||
def bind(self, val:int|UOp):
|
||||
@@ -914,7 +915,8 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def val(self) -> int: return self.unbind()[1]
|
||||
def variables(self) -> list[Variable]:
|
||||
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR}, key=lambda v: v.arg)
|
||||
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)},
|
||||
key=lambda v: v.expr)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
||||
@@ -1022,7 +1024,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||
def _sym_fxn(self):
|
||||
from tinygrad.uop.render import _render_with_splits, renderer_infer
|
||||
sself = self.simplify()
|
||||
varnames = tuple(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
|
||||
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)))
|
||||
# TODO: sanitize varnames, or don't use naked eval while staying fast
|
||||
ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself})
|
||||
lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"]
|
||||
@@ -1138,8 +1140,8 @@ class ProgramInfo:
|
||||
global_size: list[int] = [1, 1, 1]
|
||||
local_size: list[int]|None = [1, 1, 1]
|
||||
for u in sink.toposort():
|
||||
if u.op is Ops.DEFINE_VAR: _vars.append(u)
|
||||
if u.op is Ops.PARAM: _globals.append(u.arg.slot)
|
||||
if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace is None): _vars.append(u)
|
||||
if u.op is Ops.PARAM and u.addrspace is not None: _globals.append(u.arg.slot)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
|
||||
@@ -1147,9 +1149,9 @@ class ProgramInfo:
|
||||
if u.arg[0] == 'i': local_size = None
|
||||
special_size = local_size if u.arg[0] == 'l' else global_size
|
||||
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
|
||||
if u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': global_size[0] = u.arg[2] + 1
|
||||
if u.op in (Ops.DEFINE_VAR, Ops.PARAM) and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1
|
||||
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size),
|
||||
tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.arg)),
|
||||
tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.expr)),
|
||||
tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
Reference in New Issue
Block a user