Revert "improve full_graph_rewrite matchers for speed (#7431)" (#7434)

This reverts commit 996152d2de.
This commit is contained in:
George Hotz
2024-10-31 15:16:47 +07:00
committed by GitHub
parent 996152d2de
commit 2e3048fc57
3 changed files with 50 additions and 47 deletions

View File

@@ -4,9 +4,8 @@ from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, load_store_indexing, sym, float4_folding, migrate_indexing
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding, finalize, migrate_indexing
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.shape.shapetracker import ShapeTracker, View
@@ -446,10 +445,12 @@ class TestUOpGraph(unittest.TestCase):
# Helper that pushes `sink` through successive graph_rewrite passes; every pass uses `sym` plus one more matcher.
# NOTE(review): two `return` lines are listed here with no indentation; this looks like the removed and added
# sides of a diff hunk shown together (load_store_indexing vs. reducer + finalize) -- confirm against the real file.
def expander_rewrite(sink):
sink = graph_rewrite(sink, sym + expander)
return graph_rewrite(sink, sym + load_store_indexing)
sink = graph_rewrite(sink, sym + reducer)
return graph_rewrite(sink, sym + finalize)
# Helper that first rewrites `sink` with sym + migrate_indexing, then with sym + expander + float4_folding.
# NOTE(review): the expander + float4_folding pass appears twice (once as a `return`, once as an assignment followed
# by a sym + finalize pass); this looks like both sides of a diff hunk shown together -- confirm against the real file.
def float4_rewrite(sink):
sink = graph_rewrite(sink, sym + migrate_indexing)
return graph_rewrite(sink, sym + expander + float4_folding)
sink = graph_rewrite(sink, sym + expander + float4_folding)
return graph_rewrite(sink, sym + finalize)
class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
@@ -617,11 +618,11 @@ class TestLoadStoreFolder(unittest.TestCase):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)).sink()
sink = full_graph_rewrite(sink, Renderer())
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
self.assertEqual(single_load.src[1].op, UOps.CONST)
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
@@ -636,7 +637,7 @@ class TestLoadStoreFolder(unittest.TestCase):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = full_graph_rewrite(sink, Renderer())
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
def test_simple_store_fold_gate(self):
@@ -644,7 +645,7 @@ class TestLoadStoreFolder(unittest.TestCase):
gate = UOp.variable("g1", False, True, dtypes.bool)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = full_graph_rewrite(sink, Renderer())
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
assert len(one_store.src) == 3
@@ -656,7 +657,8 @@ class TestLoadStoreFolder(unittest.TestCase):
gate2 = UOp.variable("g2", False, True, dtypes.bool)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = full_graph_rewrite(sink, Renderer())
sink = float4_rewrite(sink)
print(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
class TestIFUOps(unittest.TestCase):

View File

@@ -127,15 +127,10 @@ transcendental_patterns = [
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
]
@functools.lru_cache(None)
def get_transcendental_patterns(ops, force_transcendental=False):
  """Build a PatternMatcher from the entries of `transcendental_patterns` selected for `ops`.

  An entry is kept when the `arg` of its pattern (element 0) is not contained in `ops`, or
  unconditionally when `force_transcendental` is truthy. The rewrite function (element 1) of each
  kept entry is passed through `cast(Callable, ...)`, which only affects static typing.

  Results are memoized by an unbounded `lru_cache`, so both arguments must be hashable.
  """
  selected = []
  for entry in transcendental_patterns:
    if entry[0].arg not in ops or force_transcendental:
      selected.append((entry[0], cast(Callable, entry[1])))
  return PatternMatcher(selected)
powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_extra_patterns(ops):
pat: List[Tuple[UPat, Callable]] = []
def get_extra_patterns(ops, force_transcendental=False):
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
if BinaryOps.AND in ops:
pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
@@ -457,23 +452,13 @@ devectorize = PatternMatcher([
(UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store),
])
def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, store_gate:UOp) -> Optional[UOp]:
  """Rebuild a gated store without its gate when the first IF found among its sources tests that same gate.

  The search starts at `store` and walks `src` depth-first, returning the first UOp whose op is
  `UOps.IF`. None is returned when no IF is found, or when the IF's first source is not the very
  same object as `store_gate`. Otherwise a new store is built on `buf.index(idx)` followed by all
  sources of `store` after the first one.
  """
  @functools.lru_cache(None)  # results are cached per node for the duration of this call
  def find_gate(node:UOp) -> Optional[UOp]:
    if node.op is UOps.IF: return node
    for parent in node.src:
      hit = find_gate(parent)
      if hit is not None: return hit
    return None

  enclosing_if = find_gate(store)
  if enclosing_if is None: return None
  if enclosing_if.src[0] is not store_gate: return None
  return UOp.store(buf.index(idx), *store.src[1:])
load_store_indexing = PatternMatcher([
reducer = PatternMatcher([
# late fixup of unfoldable image loads
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
# simplify valid
(UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
# image load valid idx simplification
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# delete_redundant_gates (after expand)
(UPat(UOps.STORE, src=(UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")),), allow_any_len=True, name="store"), delete_redundant_gates),
])
def idx_load_store(x:UOp):
@@ -496,39 +481,55 @@ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
pm_render = PatternMatcher([
# renderers can't deal with VCONST or multiGEP
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
# basic sym rule, don't vectorize size one
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
  """Drop the gate (third source) of STORE `root` when the first IF found among its sources tests that same gate.

  Returns None when `root` has exactly two sources, when no UOp with op `UOps.IF` is found by a
  depth-first walk over `src` starting at `root`, or when that IF's first source is not the very
  same object as `root.src[2]`. Otherwise returns a new `UOps.STORE` with the dtype and arg of
  `root` and only its first two sources.
  """
  @functools.lru_cache(None)  # results are cached per node for the duration of this call
  def find_gate(node:UOp) -> Optional[UOp]:
    if node.op is UOps.IF: return node
    for parent in node.src:
      hit = find_gate(parent)
      if hit is not None: return hit
    return None

  # a two-source store has no gate to remove
  if len(root.src) == 2: return None
  enclosing_if = find_gate(root)
  if enclosing_if is None: return None
  if enclosing_if.src[0] is not root.src[2]: return None
  return UOp(UOps.STORE, root.dtype, root.src[:2], root.arg)
# Two rules: one hands masked LOAD/STOREs to move_mask, one hands every STORE to delete_redundant_gates.
finalize = PatternMatcher([
# move masks of loads/stores
# TODO: this should be an IF instead of a masked STORE
# The first source must be either a three-source INDEX (bound as "buf", "idx", "mask") or that same INDEX pattern
# wrapped in a cast (bound as "cast"); the matched LOAD/STORE itself is bound as "x".
# allow_any_len=True presumably lets the LOAD/STORE carry further sources after the first -- confirm against UPat.
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
# delete_redundant_gates (after expand)
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
])
# for rendering, we don't use vector
pm_render = PatternMatcher([
# CONST whose dtype has vcount > 1 -> VECTORIZE of vcount copies of the scalar const (the lambda returns None otherwise)
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
# VCONST -> VECTORIZE with one scalar const per element of its arg
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
# GEP with more than one entry in its arg -> VECTORIZE of one gep() per entry on the same source (None otherwise)
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
# VECTORIZE matched with a single source pattern -> that source itself
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
])
# *** uop graph ***
# Rewrites a UOp graph rooted at a SINK through a sequence of graph_rewrite passes and returns the rewritten sink.
# `opts` is optional: when it is None the supported-op tuple is empty, float4_folding is not added to the devectorize
# pass, and no renderer extra_matcher is appended to pm_render.
# NOTE(review): this listing has no indentation and repeats several passes with slightly different matcher sets
# (e.g. two devectorize passes, two "convert REDUCE" comments); it looks like the removed and added sides of a diff
# hunk shown together -- confirm against the real file before reading every line below as one function body.
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# the input must already be wrapped in a SINK
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
# initial symbolic + migrate indexing (remove this) + early transcendental
sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))
# temp for indexing migration
sink = graph_rewrite(sink, sym+migrate_indexing)
# convert EXPAND -> VECTORIZE
# expand
sink = graph_rewrite(sink, sym+expander)
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual, belongs in lowerer)
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
sink = graph_rewrite(sink, sym+just_reduce, ctx=[0])
# devectorize + load/store indexing
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
# devectorize
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
# final rules for the renderer (without sym)
sink = graph_rewrite(sink, pm_render+get_extra_patterns(supported_ops)+extra_matcher)
# cleanups
sink = graph_rewrite(sink, sym+reducer)
# finalize
sink = graph_rewrite(sink, sym+finalize+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
# for rendering without sym (including the rules from the renderer)
sink = graph_rewrite(sink, (pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render))
return sink

View File

@@ -47,7 +47,7 @@ ptx_matcher = symbolic+PatternMatcher([
(UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
# load/store use pointer arithmetic, and the cast does nothing
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(UPat(UOps.INDEX, name="x"), lambda x: x.src[0].cast(dtypes.int64) + x.src[1].cast(dtypes.int64)*x.src[0].dtype.itemsize),
(UPat(UOps.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
])