From eb1238436ace9969924f774e5c0bdd628335225b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 7 Jun 2026 12:25:11 -0700 Subject: [PATCH] more prereqs for DL/DR -> BUFFER (#16529) --- tinygrad/codegen/__init__.py | 12 +++++++++--- tinygrad/renderer/wgsl.py | 4 ++-- tinygrad/uop/ops.py | 5 +++-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 4e3fd2f3e7..ee7c7be49f 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -4,11 +4,12 @@ import itertools from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp +from tinygrad.uop.ops import ParamArg from tinygrad.uop.render import pyrender from tinygrad.uop.spec import type_verify, spec_tensor, spec_program from tinygrad.renderer import Renderer, Estimates from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext -from tinygrad.dtype import dtypes, PtrDType, ImageDType +from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace # import all pattern matchers here from tinygrad.codegen.gpudims import pm_add_gpudims @@ -34,12 +35,17 @@ pm_index_is_shrink = PatternMatcher([ pm_remove_vec_dtypes = PatternMatcher([ # rewrite PARAM to non pointer - (UPat((Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf: + (UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf: buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \ if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None), + # no LOADs on register dtypes + (UPat(Ops.LOAD, name="x"), lambda x: x.src[0] if x.src[0].addrspace == AddrSpace.REG else None), # remove all vec dtypes (UPat(GroupOp.All-{Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"), lambda x: x.replace(dtype=x.dtype.base.scalar().base)), + # replace DEFINE_LOCAL/DEFINE_REG with BUFFER + (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x: + x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))), ])+pm_clean_up_group_sink def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: @@ -121,7 +127,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: if ren.new_style: sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink") - sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="remove vec dtypes") + sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style") # this was the linearizer sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True) diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 256db08421..3117803355 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -110,8 +110,8 @@ class WGSLRenderer(CStyleLanguage): prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast(bits); }\n" prg += "@group(0) @binding(0)\nvar INFINITY : f32;\n" prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" + - f"{'var' if isinstance(u.dtype, PtrDType) else 'var'}" + - f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs]) + f"{'var' if u.addrspace == AddrSpace.GLOBAL else 'var'}" + + f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if u.addrspace == AddrSpace.GLOBAL else self.buf_map(u.dtype)};" for name,(u,_) in bufs]) prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3," return prg + "@builtin(local_invocation_id) lindex: vec3) {\n" + "\n".join(kernel) + "\n}" diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 5e75a07e11..a23a5500f4 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -29,7 +29,7 @@ class ParamArg: device: str|tuple[str, ...]|None = None def __repr__(self): fields = (("vmin_vmax", None), ("name", None), ("addrspace", AddrSpace.GLOBAL), ("axis", None), ("device", None)) - args = [str(self.slot)] + [f"{k}={v!r}" for k,default in fields if (v:=getattr(self, k)) != default] + args = [repr(self.slot)] + [f"{k}={v!r}" for k,default in fields if (v:=getattr(self, k)) != default] return f"ParamArg({', '.join(args)})" axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u", AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"} @@ -278,7 +278,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.GETADDR: return () case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return () case Ops.BINARY: return (len(self.arg),) - case Ops.BUFFER: return (self.arg,) + case Ops.BUFFER: return self.src[0].as_shape if isinstance(self.arg, ParamArg) else (self.arg,) case Ops.SLICE: # HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX if self.src[0].op is Ops.INDEX: return () @@ -759,6 +759,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): assert isinstance(self.src[0].device, tuple), f"mselect must be on tuple device, getting {self.src[0].device}" return self.src[0].device[self.arg] if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src) + if self.op is Ops.BUFFER and isinstance(self.arg, ParamArg): return self.arg.device if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device for x in self.src: if x.device is not None: return x.device