mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
new stack can compile
This commit is contained in:
@@ -12,7 +12,7 @@ from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.expander2 import expander2, expander_broadcast
|
||||
from tinygrad.codegen.late.expander2 import expander2, devectorizer2
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
@@ -75,10 +75,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, b
|
||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
|
||||
|
||||
# devectorize (TODO: does this need opts?)
|
||||
"""
|
||||
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
|
||||
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
|
||||
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
|
||||
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
|
||||
"""
|
||||
sink = graph_rewrite(sink, devectorizer2, name="devectorize2")
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, AxisType, UOp, GroupOp, _align_left, _broadcast_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import all_same
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.schedule.rangeify import pm_index_mops
|
||||
|
||||
def build_range_map(ctx, sink:UOp):
|
||||
for x in sink.toposort():
|
||||
@@ -14,16 +16,44 @@ expander2 = PatternMatcher([
|
||||
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
|
||||
])+pm_flatten_range
|
||||
|
||||
# *** unused broadcasting, it's just in shape now ***
|
||||
|
||||
def broadcast_binary(x:UOp):
|
||||
shapes = [u.shape for u in x.src]
|
||||
print(x.op, shapes)
|
||||
if all_same(shapes): return None
|
||||
shaped_aligned = _align_left(*shapes)
|
||||
broadcasted = _broadcast_shape(*shapes)
|
||||
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
|
||||
return x.replace(src=tuple(src_reshaped))
|
||||
|
||||
expander_broadcast = PatternMatcher([
|
||||
(UPat(GroupOp.Binary|GroupOp.Ternary, name="x"), broadcast_binary),
|
||||
])
|
||||
def do_binary_devectorize(b:UOp):
|
||||
if b.shape == (): return None
|
||||
# broadcasting needs to be already unpacked
|
||||
if not all_same([x.shape for x in b.src]): return None
|
||||
assert len(b.shape) == 1
|
||||
src = []
|
||||
for i in range(b.shape[0]):
|
||||
src.append(b.replace(src=tuple([x.index(UOp.const(dtypes.weakint, i)) for x in b.src])))
|
||||
return UOp.cat(*src)
|
||||
|
||||
devectorizer2 = pm_index_mops+PatternMatcher([
|
||||
# INDEX with one src is a noop
|
||||
(UPat(Ops.INDEX, src=(UPat.var("x"),)), lambda x: x),
|
||||
# INDEX into VCONST is CONST
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VCONST, name="a"), UPat.cvar("i", vec=False))),
|
||||
lambda a,i: UOp.const(a.dtype, a.arg[i.arg])),
|
||||
# INDEX into CAT is src
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CAT, name="a"), UPat.cvar("i", vec=False))),
|
||||
lambda a,i: a.src[i.arg] if a.arg == -1 else None),
|
||||
|
||||
# cat goes through index
|
||||
(UPat(Ops.INDEX, src=(UPat.var("a"), UPat(Ops.CAT, name="c"))),
|
||||
lambda a,c: UOp.cat(*[a.index(x) for x in c.src])),
|
||||
|
||||
# cat on store is group (TODO: do we need group?)
|
||||
(UPat(Ops.CAT, src=UPat(Ops.STORE), name="x"), lambda x: UOp.group(*x.src)),
|
||||
|
||||
# unpack broadcasting
|
||||
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
|
||||
(UPat(GroupOp.Binary|{Ops.STORE}, name="b"), do_binary_devectorize),
|
||||
])
|
||||
|
||||
@@ -63,12 +63,15 @@ pm_fold_moved_after = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
|
||||
# TODO: clean up .src[0]._shape is not None
|
||||
pm_mops = PatternMatcher([
|
||||
pm_index_mops = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
|
||||
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
|
||||
if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
|
||||
])
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
|
||||
# TODO: clean up .src[0]._shape is not None
|
||||
pm_mops = pm_index_mops+PatternMatcher([
|
||||
# move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
|
||||
(UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),
|
||||
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
|
||||
|
||||
@@ -100,6 +100,7 @@ class Ops(FastEnum):
|
||||
# the core 6 movement ops! these only exist in the tensor graph
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()
|
||||
MULTI = auto() # MULTI is really a movement op
|
||||
CAT = auto() # see CAT in spec
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
|
||||
|
||||
@@ -211,7 +211,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def _shape(self) -> tuple[sint, ...]|None:
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||
return None
|
||||
@@ -244,6 +244,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
# pointer index
|
||||
#return self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
case Ops.CAT:
|
||||
if self.arg == -1:
|
||||
assert all_same([x.shape for x in self.src])
|
||||
return (len(self.src),)+self.src[0].shape
|
||||
# TODO: write the non arg=-1 path
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
case Ops.VCONST: return (len(self.arg),)
|
||||
@@ -315,7 +321,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||
|
||||
# broadcasting here
|
||||
if self.op in GroupOp.Binary|GroupOp.Ternary:
|
||||
# TODO: STORE can only broadcast a smaller src[1] into a larger src[0]
|
||||
if self.op in GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}:
|
||||
return _broadcast_shape(*[u.shape for u in self.src])
|
||||
|
||||
# elementwise ops keep the shape the same. all inputs with shape must match
|
||||
@@ -417,6 +424,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop syntactic sugar ***
|
||||
|
||||
def cat(*srcs:UOp, axis=-1): # pylint: disable=no-self-argument
|
||||
assert len(srcs) >= 1 and all_same([x.dtype for x in srcs])
|
||||
return UOp(Ops.CAT, srcs[0].dtype, src=tuple(srcs), arg=axis)
|
||||
def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
|
||||
return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
|
||||
|
||||
@@ -47,7 +47,7 @@ from tinygrad.dtype import dtypes
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", Ops.CAT: "#D8F9E4",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||
@@ -115,7 +115,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
|
||||
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
#if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
for u in toposort:
|
||||
if u in excluded: continue
|
||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||
|
||||
Reference in New Issue
Block a user