new stack can compile

This commit is contained in:
George Hotz
2026-04-22 14:04:18 +08:00
parent 7c585be215
commit 3412382cae
6 changed files with 59 additions and 12 deletions

View File

@@ -12,7 +12,7 @@ from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.expander2 import expander2, expander_broadcast
from tinygrad.codegen.late.expander2 import expander2, devectorizer2
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
from tinygrad.codegen.opt.postrange import apply_opts
@@ -75,10 +75,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, b
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
# devectorize (TODO: does this need opts?)
"""
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
"""
sink = graph_rewrite(sink, devectorizer2, name="devectorize2")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")

View File

@@ -1,6 +1,8 @@
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, AxisType, UOp, GroupOp, _align_left, _broadcast_shape
from tinygrad.dtype import dtypes
from tinygrad.helpers import all_same
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.schedule.rangeify import pm_index_mops
def build_range_map(ctx, sink:UOp):
for x in sink.toposort():
@@ -14,16 +16,44 @@ expander2 = PatternMatcher([
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
])+pm_flatten_range
# *** unused broadcasting, it's just in shape now ***
def broadcast_binary(x:UOp):
  """Make the implicit broadcast between the sources of `x` explicit.

  Returns None (no rewrite) when every source already has the same shape.
  Otherwise returns `x` with each source reshaped to its entry from
  `_align_left(*shapes)` and then expanded to `_broadcast_shape(*shapes)`.

  NOTE(review): presumably `_align_left` pads the shapes to a common rank and
  `_broadcast_shape` computes the broadcast result shape; neither helper is
  defined in this file, so confirm against their definitions.
  """
  shapes = [u.shape for u in x.src]
  # nothing to unpack when all sources already agree
  if all_same(shapes): return None
  aligned_shapes = _align_left(*shapes)
  target_shape = _broadcast_shape(*shapes)
  # reshape each source to its aligned shape, then expand it to the common target shape
  new_src = tuple(u.reshape(shp).expand(target_shape) for u,shp in zip(x.src, aligned_shapes))
  return x.replace(src=new_src)
# matcher with a single rule: every binary and ternary op is passed to broadcast_binary
expander_broadcast = PatternMatcher([
(UPat(GroupOp.Binary|GroupOp.Ternary, name="x"), broadcast_binary),
])
def do_binary_devectorize(b:UOp):
  """Unpack a 1-D shaped op into one op per element, joined with UOp.cat.

  Returns None (no rewrite) when `b` has shape () or when its sources do not
  all share the same shape. Otherwise asserts `b` is 1-D and builds, for each
  position along that dimension, a copy of `b` whose sources are each indexed
  with a weakint constant for that position.
  """
  # shape () means there is nothing to unpack
  if b.shape == (): return None
  # broadcasting needs to be already unpacked
  src_shapes = [s.shape for s in b.src]
  if not all_same(src_shapes): return None
  assert len(b.shape) == 1
  lanes = [b.replace(src=tuple(s.index(UOp.const(dtypes.weakint, lane)) for s in b.src)) for lane in range(b.shape[0])]
  return UOp.cat(*lanes)
# rewrite rules used by the "devectorize2" pass: the pm_index_mops rules combined with the
# INDEX/CAT folding rules and the broadcast/devectorize rules listed below
devectorizer2 = pm_index_mops+PatternMatcher([
# INDEX with one src is a noop
(UPat(Ops.INDEX, src=(UPat.var("x"),)), lambda x: x),
# INDEX into VCONST is CONST: pick element i.arg out of the VCONST's arg (the const keeps a.dtype)
(UPat(Ops.INDEX, src=(UPat(Ops.VCONST, name="a"), UPat.cvar("i", vec=False))),
lambda a,i: UOp.const(a.dtype, a.arg[i.arg])),
# INDEX into CAT is src: select source i.arg of the CAT, only handled when the CAT arg is -1
(UPat(Ops.INDEX, src=(UPat(Ops.CAT, name="a"), UPat.cvar("i", vec=False))),
lambda a,i: a.src[i.arg] if a.arg == -1 else None),
# cat goes through index: INDEX of `a` by a CAT becomes a CAT of `a` indexed by each CAT source
(UPat(Ops.INDEX, src=(UPat.var("a"), UPat(Ops.CAT, name="c"))),
lambda a,c: UOp.cat(*[a.index(x) for x in c.src])),
# cat on store is group (TODO: do we need group?)
(UPat(Ops.CAT, src=UPat(Ops.STORE), name="x"), lambda x: UOp.group(*x.src)),
# unpack broadcasting
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
# split binary ops and stores into per-element ops
# NOTE(review): GroupOp.Ternary is matched by the broadcast rule above but not by this one -- confirm intended
(UPat(GroupOp.Binary|{Ops.STORE}, name="b"), do_binary_devectorize),
])

View File

@@ -63,12 +63,15 @@ pm_fold_moved_after = PatternMatcher([
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
# movement op on INDEX as a PatternMatcher
# TODO: clean up .src[0]._shape is not None
pm_mops = PatternMatcher([
pm_index_mops = PatternMatcher([
# presumably matches an INDEX (`idx`) whose first source is the movement op `r` (built via UPat.f) -- confirm.
# the rewrite indexes r's own source instead, with the indices apply_movement_op returns for
# r.op, the source shape, r.marg and the original indices idx.src[1:]; dtype and arg of idx are kept.
# it only fires when the source has a shape and there is exactly one index per dim of r.shape.
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
])
# movement op on INDEX as a PatternMatcher
# TODO: clean up .src[0]._shape is not None
pm_mops = pm_index_mops+PatternMatcher([
# move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
(UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),

View File

@@ -100,6 +100,7 @@ class Ops(FastEnum):
# the core 6 movement ops! these only exist in the tensor graph
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()
MULTI = auto() # MULTI is really a movement op
CAT = auto() # see CAT in spec
# reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()

View File

@@ -211,7 +211,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def _shape(self) -> tuple[sint, ...]|None:
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None
@@ -244,6 +244,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# pointer index
#return self.src[0].shape[len(self.src[1:]):]
case Ops.CAT:
if self.arg == -1:
assert all_same([x.shape for x in self.src])
return (len(self.src),)+self.src[0].shape
# TODO: write the non arg=-1 path
# some ops init the shape
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
case Ops.VCONST: return (len(self.arg),)
@@ -315,7 +321,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
# broadcasting here
if self.op in GroupOp.Binary|GroupOp.Ternary:
# TODO: STORE can only broadcast a smaller src[1] into a larger src[0]
if self.op in GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}:
return _broadcast_shape(*[u.shape for u in self.src])
# elementwise ops keep the shape the same. all inputs with shape must match
@@ -417,6 +424,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# *** uop syntactic sugar ***
def cat(*srcs:UOp, axis=-1): # pylint: disable=no-self-argument
  """Build a CAT UOp over `srcs`, storing `axis` as its arg.

  At least one source is required and all sources must share a dtype;
  the resulting UOp takes the dtype of the first source.
  """
  assert len(srcs) >= 1 and all_same([s.dtype for s in srcs])
  first = srcs[0]
  return UOp(Ops.CAT, first.dtype, src=tuple(srcs), arg=axis)
def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
  """Build a SINK UOp (dtype void) over `srcs`, dropping any source that is None.

  Extra keyword arguments are forwarded unchanged to the UOp constructor.
  """
  present = tuple(s for s in srcs if s is not None)
  return UOp(Ops.SINK, dtypes.void, present, **kwargs)
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument

View File

@@ -47,7 +47,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.SHAPED_WMMA: "#FF5B5B",
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", Ops.CAT: "#D8F9E4",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
@@ -115,7 +115,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
#if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
for u in toposort:
if u in excluded: continue
argst = codecs.decode(str(u.arg), "unicode_escape")