diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py
index 8655019a60..46c7c760c8 100644
--- a/test/unit/test_graph_rewrite.py
+++ b/test/unit/test_graph_rewrite.py
@@ -334,5 +334,19 @@ class TestBidirectional(unittest.TestCase):
     graph_rewrite(c, pm, ctx=ctx_list, bpm=bpm)
     self.assertListEqual(ctx_list, [('+', True), (1, True), (1, False), (2, True), (2, False), ('+', False)])
 
+class TestStopEarly(unittest.TestCase):
+  def test_stop_early(self):
+    a = UOp.const(dtypes.int, 3)
+    b = UOp.const(dtypes.int, 4)
+    c = a+b
+    cn = UOp.const(dtypes.int, 7)
+    d = UOp.const(dtypes.int, 2)
+    def visit_const(c:UOp):
+      print(f"visit {c.arg}")
+      assert c.arg not in (3,4)
+    pm_cvisit = PatternMatcher([(UPat(Ops.CONST, name="c"), visit_const),])
+    ret = (c+d).substitute({c:cn}, extra_pm=pm_cvisit)
+    assert ret == cn+d
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 6cc2612d7b..408c11e578 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -9,6 +9,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
 from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
+from tinygrad.schedule.rangeify import mop_cleanup
 from dataclasses import dataclass
 from weakref import WeakKeyDictionary
 
@@ -224,7 +225,7 @@ def _prepare_jit_inputs(args, kwargs):
   input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
                                          for lb in lbs if lb.base.realized is not None])
   assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
-  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
+  st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs]
   _var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
   var_vals = {k.expr:v for k,v in _var_vals.items()}
   st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 90db992bca..f324aae6cf 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
 from tinygrad.uop.symbolic import symbolic_simple
-from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP, Metadata
+from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
 from tinygrad.codegen.opt import Opt
 from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@@ -13,6 +13,12 @@ from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTI
 import sys
 sys.setrecursionlimit(10000)
 
+# movement op on INDEX as a PatternMatcher
+pm_mops = PatternMatcher([
+  (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
+   lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), # type: ignore
+])
+
 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff
 
@@ -31,7 +37,12 @@ def split_reduceop(reduce:UOp, x:UOp):
   # ~2**10 should be enough if GROUP is used
   # 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
   # split is moved to the end to provide maximum locality for the second phase reduce.
-  is_expanded = unwrap(x.st).is_expanded()
+
+  # get expanded by rangeifying the UOp x
+  indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)])
+  range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges]
+  is_expanded = [i not in range_nums for i in range(len(x.shape))]
+
   if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
                              if x.shape[i]%d==0 and not is_expanded[i]]): return None
   dim_to_split, divisor = split_candidates[0]
@@ -41,13 +52,15 @@ def split_reduceop(reduce:UOp, x:UOp):
   # reduce original axes, then split
   return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape).replace(tag=reduce.tag)
 
-earliest_rewrites = PatternMatcher([
-  # just removing it works...
-  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
-
+mop_cleanup = PatternMatcher([
   # merge adjacent RESHAPES, safe because they are not tagged
   (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
    lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
+])
+
+earliest_rewrites = mop_cleanup+PatternMatcher([
+  # just removing it works...
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
 
   # remove CONTIGUOUS if the BUFFER is already contiguous
   (UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
@@ -95,15 +108,6 @@ earliest_rewrites = PatternMatcher([
   (UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
 ])
 
-# *****************
-# 3a. rangeify (movement)
-
-# movement op on INDEX as a PatternMatcher
-pm_mops = PatternMatcher([
-  (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
-   lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), # type: ignore
-])
-
 # *****************
 # 3.5 cleanups
 
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index fe98b4af67..83bba337fa 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -351,11 +351,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def __bool__(self): return self._eval((dtypes.bool,), bool)
   def __int__(self): return self._eval(dtypes.ints, int)
   def __float__(self): return self._eval(dtypes.floats, float)
-  def substitute(self, dvars:dict[UOp, UOp], name:str|None=None):
+  def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None):
     dvars = {k:v for k,v in dvars.items() if k is not v}
     if len(dvars) == 0: return self
     with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
-      return graph_rewrite(self, _substitute, dvars, bottom_up=True, name=name)
+      return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars, bottom_up=True, name=name)
 
   # *** uop tracing stuff ***
 
@@ -647,6 +647,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def unbind(self) -> tuple[Variable, int]:
     assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
     return self.src[0], self.src[1].arg
+  def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
+    ret:dict[Variable, int] = {}
+    return graph_rewrite(self, pm_unbind, ctx=ret), ret
   @property
   def val(self) -> int: return self.unbind()[1]
   def vars(self) -> set[UOp]:
@@ -1220,6 +1223,12 @@ def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index
 
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
 
+def do_unbind(ctx:dict[Variable, int], x:UOp):
+  v,i = x.unbind()
+  ctx[v] = i
+  return v
+pm_unbind = PatternMatcher([(UPat(Ops.BIND, name="x"), do_unbind)])
+
 # for debug
 syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
          Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}