mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 00:45:16 +08:00
remove st from jit/split_reduceop (#12713)
* remove st from jit
* fix by merging reshapes
* no st usage in rangeify
* hmm, stop early works
* fix speed regressions
This commit is contained in:
@@ -334,5 +334,19 @@ class TestBidirectional(unittest.TestCase):
|
||||
graph_rewrite(c, pm, ctx=ctx_list, bpm=bpm)
|
||||
self.assertListEqual(ctx_list, [('+', True), (1, True), (1, False), (2, True), (2, False), ('+', False)])
|
||||
|
||||
class TestStopEarly(unittest.TestCase):
  def test_stop_early(self):
    # build (3+4)+2, plus the constant 7 that will be substituted for the inner 3+4
    three, four = UOp.const(dtypes.int, 3), UOp.const(dtypes.int, 4)
    inner_sum = three + four
    seven = UOp.const(dtypes.int, 7)
    two = UOp.const(dtypes.int, 2)

    # callback for the extra matcher: it is handed CONST nodes and fails if it ever sees 3 or 4,
    # so the test can only pass when the sources of the replaced sum are never matched
    def visit_const(c:UOp):
      print(f"visit {c.arg}")
      assert c.arg not in (3,4)
    const_spy = PatternMatcher([(UPat(Ops.CONST, name="c"), visit_const),])

    rewritten = (inner_sum + two).substitute({inner_sum: seven}, extra_pm=const_spy)
    assert rewritten == seven + two
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -9,6 +9,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.schedule.rangeify import mop_cleanup
|
||||
from dataclasses import dataclass
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
@@ -224,7 +225,7 @@ def _prepare_jit_inputs(args, kwargs):
|
||||
input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
|
||||
for lb in lbs if lb.base.realized is not None])
|
||||
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
|
||||
st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
|
||||
st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs]
|
||||
_var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
|
||||
var_vals = {k.expr:v for k,v in _var_vals.items()}
|
||||
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP, Metadata
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
@@ -13,6 +13,12 @@ from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTI
|
||||
import sys
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
def _index_through_movement_op(r, idx):
  # r is the movement op feeding the INDEX and idx is the INDEX itself (names bound by the UPat below)
  parent = r.src[0]
  new_indices = apply_movement_op(r.op, parent.shape, r.marg, idx.src[1:]) # type: ignore
  # index the movement op's source directly, keeping the INDEX's dtype and arg
  return parent.index(*new_indices, dtype=idx.dtype, arg=idx.arg)

pm_mops = PatternMatcher([
  (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), _index_through_movement_op),
])
|
||||
|
||||
# *****************
|
||||
# 0. do some cleanup rewrites, mostly copied from the old stuff
|
||||
|
||||
@@ -31,7 +37,12 @@ def split_reduceop(reduce:UOp, x:UOp):
|
||||
# ~2**10 should be enough if GROUP is used
|
||||
# 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
|
||||
# split is moved to the end to provide maximum locality for the second phase reduce.
|
||||
is_expanded = unwrap(x.st).is_expanded()
|
||||
|
||||
# get expanded by rangeifying the UOp x
|
||||
indexed = x.index(*[UOp.range(s, i) if resolve(s>1) else UOp.const(dtypes.index, 0) for i,s in enumerate(x.shape)])
|
||||
range_nums = [y.arg[0] for y in indexed.substitute({x.base:UOp(Ops.NOOP)}, extra_pm=pm_mops).ranges]
|
||||
is_expanded = [i not in range_nums for i in range(len(x.shape))]
|
||||
|
||||
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
|
||||
if x.shape[i]%d==0 and not is_expanded[i]]): return None
|
||||
dim_to_split, divisor = split_candidates[0]
|
||||
@@ -41,13 +52,15 @@ def split_reduceop(reduce:UOp, x:UOp):
|
||||
# reduce original axes, then split
|
||||
return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape).replace(tag=reduce.tag)
|
||||
|
||||
earliest_rewrites = PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
|
||||
|
||||
def _merge_reshape_pair(x, x2):
  # x is the outer RESHAPE and x2 the RESHAPE it reads from; if either carries a tag, decline the rewrite
  if x.tag is not None or x2.tag is not None: return None
  # keep the outer node's second source, but read straight from the inner RESHAPE's input
  return x.replace(src=(x2.src[0], x.src[1]))

mop_cleanup = PatternMatcher([
  # merge adjacent RESHAPES, safe because they are not tagged
  (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"), _merge_reshape_pair),
])
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# remove CONTIGUOUS if the BUFFER is already contiguous
|
||||
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
@@ -95,15 +108,6 @@ earliest_rewrites = PatternMatcher([
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
|
||||
])
|
||||
|
||||
# *****************
|
||||
# 3a. rangeify (movement)
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
def _index_through_movement_op(r, idx):
  # r is the movement op feeding the INDEX and idx is the INDEX itself (names bound by the UPat below)
  parent = r.src[0]
  new_indices = apply_movement_op(r.op, parent.shape, r.marg, idx.src[1:]) # type: ignore
  # index the movement op's source directly, keeping the INDEX's dtype and arg
  return parent.index(*new_indices, dtype=idx.dtype, arg=idx.arg)

pm_mops = PatternMatcher([
  (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), _index_through_movement_op),
])
|
||||
|
||||
# *****************
|
||||
# 3.5 cleanups
|
||||
|
||||
|
||||
@@ -351,11 +351,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def __bool__(self):
  # delegates to _eval with the (dtypes.bool,) tuple and the builtin bool
  return self._eval((dtypes.bool,), bool)
|
||||
def __int__(self):
  # delegates to _eval with dtypes.ints and the builtin int
  return self._eval(dtypes.ints, int)
|
||||
def __float__(self):
  # delegates to _eval with dtypes.floats and the builtin float
  return self._eval(dtypes.floats, float)
|
||||
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None):
  # drop entries that map a node to itself; with nothing left, self is returned unchanged
  dvars = {k:v for k,v in dvars.items() if k is not v}
  if not dvars: return self
  # TRACK_MATCH_STATS is forced to 0 for the rewrite unless the caller supplied a name
  track = TRACK_MATCH_STATS.value if name is not None else 0
  with Context(TRACK_MATCH_STATS=track):
    # when extra_pm is given, its patterns are placed ahead of the substitution rule
    matcher = _substitute if extra_pm is None else (extra_pm+_substitute)
    return graph_rewrite(self, matcher, dvars, bottom_up=True, name=name)
|
||||
|
||||
# *** uop tracing stuff ***
|
||||
|
||||
@@ -647,6 +647,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def unbind(self) -> tuple[Variable, int]:
  # only a BIND whose sources are (DEFINE_VAR, CONST) is accepted
  assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
  variable, bound_const = self.src[0], self.src[1]
  # return the variable together with the arg of the constant it was bound to
  return variable, bound_const.arg
|
||||
def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
  # this dict is handed to graph_rewrite as ctx for pm_unbind and returned alongside the rewritten graph
  var_vals:dict[Variable, int] = {}
  unbound = graph_rewrite(self, pm_unbind, ctx=var_vals)
  return unbound, var_vals
|
||||
@property
def val(self) -> int:
  # the value is the second element of what unbind() returns
  unbound = self.unbind()
  return unbound[1]
|
||||
def vars(self) -> set[UOp]:
|
||||
@@ -1220,6 +1223,12 @@ def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index
|
||||
|
||||
# substitution rule: a node of any op is looked up in the ctx mapping; a miss returns None
_substitute = PatternMatcher([
  (UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x)),
])
|
||||
|
||||
def do_unbind(ctx:dict[Variable, int], x:UOp):
  # record the unbound value under its variable in ctx, then return the bare variable
  var, bound_value = x.unbind()
  ctx[var] = bound_value
  return var

# applies do_unbind to every BIND node
pm_unbind = PatternMatcher([
  (UPat(Ops.BIND, name="x"), do_unbind),
])
|
||||
|
||||
# for debug
# maps each listed op to its operator symbol as a string
syms = {
  Ops.ADD: "+",
  Ops.SUB: "-",
  Ops.IDIV: "//",
  Ops.MOD: "%",
  Ops.SHL: "<<",
  Ops.SHR: ">>",
  Ops.MUL: "*",
  Ops.CMPLT: "<",
  Ops.CMPNE: "!=",
  Ops.AND: "&",
  Ops.OR: "|",
  Ops.XOR: "^",
}
|
||||
|
||||
Reference in New Issue
Block a user