remove st from jit/split_reduceop (#12713)

* remove st from jit

* fix by merging reshapes

* no st usage in rangeify

* hmm, stop early works

* fix speed regressions
This commit is contained in:
George Hotz
2025-10-16 12:50:58 +08:00
committed by GitHub
parent 069177c1be
commit 7c19db00f1
4 changed files with 46 additions and 18 deletions

View File

@@ -334,5 +334,19 @@ class TestBidirectional(unittest.TestCase):
graph_rewrite(c, pm, ctx=ctx_list, bpm=bpm)
self.assertListEqual(ctx_list, [('+', True), (1, True), (1, False), (2, True), (2, False), ('+', False)])
class TestStopEarly(unittest.TestCase):
  def test_stop_early(self):
    # build (3+4)+2; the inner 3+4 node is the one handed to substitute
    lhs, rhs = UOp.const(dtypes.int, 3), UOp.const(dtypes.int, 4)
    inner = lhs+rhs
    replacement = UOp.const(dtypes.int, 7)
    two = UOp.const(dtypes.int, 2)
    # extra rule that fires on CONST nodes; it has no return value, it only observes what gets visited
    # NOTE: the parameter must stay named `c`, it is bound by name="c" in the UPat below
    def visit_const(c:UOp):
      print(f"visit {c.arg}")
      # 3 and 4 only appear underneath the substituted node, so they must never be visited
      assert c.arg not in (3,4)
    pm_cvisit = PatternMatcher([(UPat(Ops.CONST, name="c"), visit_const),])
    outer = inner+two
    ret = outer.substitute({inner:replacement}, extra_pm=pm_cvisit)
    # the whole 3+4 subtree is swapped for the constant 7
    assert ret == replacement+two
if __name__ == '__main__':
unittest.main()

View File

@@ -9,6 +9,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from tinygrad.schedule.rangeify import mop_cleanup
from dataclasses import dataclass
from weakref import WeakKeyDictionary
@@ -224,7 +225,7 @@ def _prepare_jit_inputs(args, kwargs):
input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
for lb in lbs if lb.base.realized is not None])
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs]
_var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
var_vals = {k.expr:v for k,v in _var_vals.items()}
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]

View File

@@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP, Metadata
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@@ -13,6 +13,12 @@ from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTI
import sys
sys.setrecursionlimit(10000)
# movement op on INDEX as a PatternMatcher
# NOTE(review): `.f(Ops.INDEX, ...)` presumably makes the movement op the first source of the matched INDEX -- confirm
pm_mops = PatternMatcher([
  # r is the movement op, idx is the INDEX that matched around it.
  # The replacement indexes r.src[0] directly: its index sources are whatever apply_movement_op returns for
  # (r.op, shape of r.src[0], r.marg, the original index sources idx.src[1:]); idx's dtype and arg are carried over.
  (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
   lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), # type: ignore
])
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
@@ -31,7 +37,12 @@ def split_reduceop(reduce:UOp, x:UOp):
# ~2**10 should be enough if GROUP is used
# 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
# split is moved to the end to provide maximum locality for the second phase reduce.
is_expanded = unwrap(x.st).is_expanded()
# get expanded by rangeifying the UOp x
indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)])
range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges]
is_expanded = [i not in range_nums for i in range(len(x.shape))]
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
if x.shape[i]%d==0 and not is_expanded[i]]): return None
dim_to_split, divisor = split_candidates[0]
@@ -41,13 +52,15 @@ def split_reduceop(reduce:UOp, x:UOp):
# reduce original axes, then split
return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape).replace(tag=reduce.tag)
earliest_rewrites = PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
mop_cleanup = PatternMatcher([
  # merge adjacent RESHAPES, safe because they are not tagged
  # x2 is the RESHAPE matched first and x the RESHAPE matched around it (presumably x2 is x's first source -- confirm).
  # x is rebuilt with x2's first source in place of x2, keeping x's own src[1]; the rule returns None if either has a tag.
  # NOTE(review): src[1] is presumably the target shape of the RESHAPE -- confirm
  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
   lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
])
earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
# remove CONTIGUOUS if the BUFFER is already contiguous
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
@@ -95,15 +108,6 @@ earliest_rewrites = PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
])
# *****************
# 3a. rangeify (movement)
# movement op on INDEX as a PatternMatcher
pm_mops = PatternMatcher([
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), # type: ignore
])
# *****************
# 3.5 cleanups

View File

@@ -351,11 +351,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None):
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None):
dvars = {k:v for k,v in dvars.items() if k is not v}
if len(dvars) == 0: return self
with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
return graph_rewrite(self, _substitute, dvars, bottom_up=True, name=name)
return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars, bottom_up=True, name=name)
# *** uop tracing stuff ***
@@ -647,6 +647,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def unbind(self) -> tuple[Variable, int]:
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
return self.src[0], self.src[1].arg
def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
  # rewrite this graph with pm_unbind, using a fresh dict as the rewrite context;
  # returns the rewritten UOp together with that dict
  var_vals:dict[Variable, int] = {}
  unbound = graph_rewrite(self, pm_unbind, ctx=var_vals)
  return unbound, var_vals
@property
def val(self) -> int: return self.unbind()[1]
def vars(self) -> set[UOp]:
@@ -1220,6 +1223,12 @@ def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index
# single rule that matches every op and looks the matched node up in ctx; a missing key yields None
# NOTE(review): ctx is presumably the dvars mapping that UOp.substitute hands to graph_rewrite -- confirm
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
def do_unbind(ctx:dict[Variable, int], x:UOp):
  # split x into (variable, value) via x.unbind(), record the value in ctx, and return the bare variable
  var, val = x.unbind()
  ctx[var] = val
  return var
# applies do_unbind to every BIND node; the rewrite context is the dict being filled in
pm_unbind = PatternMatcher([(UPat(Ops.BIND, name="x"), do_unbind)])
# for debug
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}