mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
This reverts commit 996152d2de.
This commit is contained in:
@@ -127,15 +127,10 @@ transcendental_patterns = [
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
|
||||
]
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_transcendental_patterns(ops, force_transcendental=False):
|
||||
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
powers_of_two = {2**i:i for i in range(64)}
|
||||
@functools.lru_cache(None)
|
||||
def get_extra_patterns(ops):
|
||||
pat: List[Tuple[UPat, Callable]] = []
|
||||
def get_extra_patterns(ops, force_transcendental=False):
|
||||
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
|
||||
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
|
||||
if BinaryOps.AND in ops:
|
||||
pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
|
||||
@@ -457,23 +452,13 @@ devectorize = PatternMatcher([
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store),
|
||||
])
|
||||
|
||||
def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, store_gate:UOp) -> Optional[UOp]:
|
||||
@functools.lru_cache(None)
|
||||
def find_gate(x:UOp) -> Optional[UOp]:
|
||||
if x.op is UOps.IF: return x
|
||||
return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None)
|
||||
if (gate:=find_gate(store)) is None or gate.src[0] is not store_gate: return None
|
||||
return UOp.store(buf.index(idx), *store.src[1:])
|
||||
|
||||
load_store_indexing = PatternMatcher([
|
||||
reducer = PatternMatcher([
|
||||
# late fixup of unfoldable image loads
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
|
||||
# simplify valid
|
||||
(UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
|
||||
# image load valid idx simplification
|
||||
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
|
||||
# delete_redundant_gates (after expand)
|
||||
(UPat(UOps.STORE, src=(UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")),), allow_any_len=True, name="store"), delete_redundant_gates),
|
||||
])
|
||||
|
||||
def idx_load_store(x:UOp):
|
||||
@@ -496,39 +481,55 @@ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp
|
||||
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
|
||||
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
|
||||
|
||||
pm_render = PatternMatcher([
|
||||
# renderers can't deal with VCONST or multiGEP
|
||||
(UPat(UOps.CONST, name='c'),
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
|
||||
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
# basic sym rule, don't vectorize size one
|
||||
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
|
||||
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
|
||||
@functools.lru_cache(None)
|
||||
def find_gate(x:UOp) -> Optional[UOp]:
|
||||
if x.op is UOps.IF: return x
|
||||
return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None)
|
||||
if len(root.src) == 2 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[2]: return None
|
||||
return UOp(UOps.STORE, root.dtype, root.src[:2], root.arg)
|
||||
|
||||
finalize = PatternMatcher([
|
||||
# move masks of loads/stores
|
||||
# TODO: this should be an IF instead of a masked STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
|
||||
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
|
||||
# delete_redundant_gates (after expand)
|
||||
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
|
||||
])
|
||||
|
||||
# for rendering, we don't use vector
|
||||
pm_render = PatternMatcher([
|
||||
(UPat(UOps.CONST, name='c'),
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
|
||||
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
# initial symbolic + migrate indexing (remove this) + early transcendental
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))
|
||||
# temp for indexing migration
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing)
|
||||
|
||||
# convert EXPAND -> VECTORIZE
|
||||
# expand
|
||||
sink = graph_rewrite(sink, sym+expander)
|
||||
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual, belongs in lowerer)
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
|
||||
sink = graph_rewrite(sink, sym+just_reduce, ctx=[0])
|
||||
|
||||
# devectorize + load/store indexing
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
|
||||
# devectorize
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
sink = graph_rewrite(sink, pm_render+get_extra_patterns(supported_ops)+extra_matcher)
|
||||
# cleanups
|
||||
sink = graph_rewrite(sink, sym+reducer)
|
||||
|
||||
# finalize
|
||||
sink = graph_rewrite(sink, sym+finalize+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
|
||||
# for rendering without sym (including the rules from the renderer)
|
||||
sink = graph_rewrite(sink, (pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render))
|
||||
return sink
|
||||
|
||||
Reference in New Issue
Block a user