From 3412382caee80ec79be186bdfac9a9f793dffde8 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 22 Apr 2026 14:04:18 +0800 Subject: [PATCH] new stack can compile --- tinygrad/codegen/__init__.py | 5 +++- tinygrad/codegen/late/expander2.py | 38 ++++++++++++++++++++++++++---- tinygrad/schedule/rangeify.py | 9 ++++--- tinygrad/uop/__init__.py | 1 + tinygrad/uop/ops.py | 14 +++++++++-- tinygrad/viz/serve.py | 4 ++-- 6 files changed, 59 insertions(+), 12 deletions(-) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 71d4cb8bfe..24135d0f42 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -12,7 +12,7 @@ from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce -from tinygrad.codegen.late.expander2 import expander2, expander_broadcast +from tinygrad.codegen.late.expander2 import expander2, devectorizer2 from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images from tinygrad.codegen.opt.postrange import apply_opts @@ -75,10 +75,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, b if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True) # devectorize (TODO: does this need opts?) + """ if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize") + """ + sink = graph_rewrite(sink, devectorizer2, name="devectorize2") # lower the index dtype to a concrete int sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes") diff --git a/tinygrad/codegen/late/expander2.py b/tinygrad/codegen/late/expander2.py index f417bd3b79..62c9a8c699 100644 --- a/tinygrad/codegen/late/expander2.py +++ b/tinygrad/codegen/late/expander2.py @@ -1,6 +1,8 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, AxisType, UOp, GroupOp, _align_left, _broadcast_shape +from tinygrad.dtype import dtypes from tinygrad.helpers import all_same from tinygrad.codegen.simplify import pm_flatten_range +from tinygrad.schedule.rangeify import pm_index_mops def build_range_map(ctx, sink:UOp): for x in sink.toposort(): @@ -14,16 +16,44 @@ expander2 = PatternMatcher([ .reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None), ])+pm_flatten_range -# *** unused broadcasting, it's just in shape now *** def broadcast_binary(x:UOp): shapes = [u.shape for u in x.src] + print(x.op, shapes) if all_same(shapes): return None shaped_aligned = _align_left(*shapes) broadcasted = _broadcast_shape(*shapes) src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)] return x.replace(src=tuple(src_reshaped)) -expander_broadcast = PatternMatcher([ - (UPat(GroupOp.Binary|GroupOp.Ternary, name="x"), broadcast_binary), -]) \ No newline at end of file +def do_binary_devectorize(b:UOp): + if b.shape == (): return None + # broadcasting needs to be already unpacked + if not all_same([x.shape for x in b.src]): return None + assert len(b.shape) == 1 + src = [] + for i in range(b.shape[0]): + src.append(b.replace(src=tuple([x.index(UOp.const(dtypes.weakint, i)) for x in b.src]))) + return UOp.cat(*src) + +devectorizer2 = pm_index_mops+PatternMatcher([ + # INDEX with one src is a noop + (UPat(Ops.INDEX, src=(UPat.var("x"),)), lambda x: x), + # INDEX into VCONST is CONST + (UPat(Ops.INDEX, src=(UPat(Ops.VCONST, name="a"), UPat.cvar("i", vec=False))), + lambda a,i: UOp.const(a.dtype, a.arg[i.arg])), + # INDEX into CAT is src + (UPat(Ops.INDEX, src=(UPat(Ops.CAT, name="a"), UPat.cvar("i", vec=False))), + lambda a,i: a.src[i.arg] if a.arg == -1 else None), + + # cat goes through index + (UPat(Ops.INDEX, src=(UPat.var("a"), UPat(Ops.CAT, name="c"))), + lambda a,c: UOp.cat(*[a.index(x) for x in c.src])), + + # cat on store is group (TODO: do we need group?) + (UPat(Ops.CAT, src=UPat(Ops.STORE), name="x"), lambda x: UOp.group(*x.src)), + + # unpack broadcasting + (UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary), + (UPat(GroupOp.Binary|{Ops.STORE}, name="b"), do_binary_devectorize), +]) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 03717ba1a0..05e21c939e 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -63,12 +63,15 @@ pm_fold_moved_after = PatternMatcher([ (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), ]) -# movement op on INDEX as a PatternMatcher -# TODO: clean up .src[0]._shape is not None -pm_mops = PatternMatcher([ +pm_index_mops = PatternMatcher([ (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg) if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None), +]) + +# movement op on INDEX as a PatternMatcher +# TODO: clean up .src[0]._shape is not None +pm_mops = pm_index_mops+PatternMatcher([ # move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after) (UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True), lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)), diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index e93bc43eaf..42a2c4a545 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -100,6 +100,7 @@ class Ops(FastEnum): # the core 6 movement ops! these only exist in the tensor graph RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() MULTI = auto() # MULTI is really a movement op + CAT = auto() # see CAT in spec # reduce REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 211e30b57a..bfd7fa4735 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -211,7 +211,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): def _shape(self) -> tuple[sint, ...]|None: match self.op: # late ops don't have shape - case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ + case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ Ops.VECTORIZE | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \ Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION: return None @@ -244,6 +244,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # pointer index #return self.src[0].shape[len(self.src[1:]):] + case Ops.CAT: + if self.arg == -1: + assert all_same([x.shape for x in self.src]) + return (len(self.src),)+self.src[0].shape + # TODO: write the non arg=-1 path + # some ops init the shape case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return () case Ops.VCONST: return (len(self.arg),) @@ -315,7 +321,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return tuple(1 if i in axis_arg else s for i,s in enumerate(ps)) # broadcasting here - if self.op in GroupOp.Binary|GroupOp.Ternary: + # TODO: STORE can only broadcast a smaller src[1] into a larger src[0] + if self.op in GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}: return _broadcast_shape(*[u.shape for u in self.src]) # elementwise ops keep the shape the same. all inputs with shape must match @@ -417,6 +424,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # *** uop syntactic sugar *** + def cat(*srcs:UOp, axis=-1): # pylint: disable=no-self-argument + assert len(srcs) >= 1 and all_same([x.dtype for x in srcs]) + return UOp(Ops.CAT, srcs[0].dtype, src=tuple(srcs), arg=axis) def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs) def maketuple(*srcs:UOp): # pylint: disable=no-self-argument diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 4978423cfd..95508dc6e6 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -47,7 +47,7 @@ from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.SHAPED_WMMA: "#FF5B5B", Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", + Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", Ops.CAT: "#D8F9E4", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6", Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040", @@ -115,7 +115,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]: if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u) if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u) # exclude RESHAPE/EXPAND that only serve to broadcast a CONST - if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u) + #if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u) for u in toposort: if u in excluded: continue argst = codecs.decode(str(u.arg), "unicode_escape")