diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 3a855da1af..806bba0b1e 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -5,16 +5,16 @@ from tinygrad import dtypes, Device from tinygrad.dtype import PtrDType from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo -from tinygrad.ops import NOp, PatternMatcher +from tinygrad.ops import UPat, PatternMatcher from tinygrad.codegen.lowerer import ast_to_uop from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding from tinygrad.shape.shapetracker import ShapeTracker, View simple_pm = PatternMatcher([ - (NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)), - (NOp.cvar('x') + NOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)), - (NOp.cvar('x') * NOp.cvar('y') * NOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)), - ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)), + (UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)), + (UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)), + (UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)), + ((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)), ]) def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u))) diff --git a/test/test_uops.py b/test/test_uops.py index 9964f877f4..f520135ecb 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -7,7 +7,6 @@ from tinygrad.helpers import CI, DEBUG, getenv, Context from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.device import Buffer, Device from tinygrad.ops import UOps, UOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401 -from tinygrad.ops import NOp from tinygrad.renderer import Program from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel @@ -389,10 +388,6 @@ class TestUOpStr(unittest.TestCase): sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],)) assert_equiv_uops(sink, eval(str(sink))) - def test_nop_str(self): - a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1") - assert str(eval(str(a))) == str(a) - def test_variable_const(self): # TODO: this is not possible after VALID. uop = UOp(UOps.CONST, dtypes.int, (), arg=Variable("a",1,10)) diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index 183baa3e70..dedddac9fc 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -30,9 +30,9 @@ class TestPatternMatcher(unittest.TestCase): def test_arg(self): matcher = PatternMatcher([ - (UPat(UOps.CONST, 0, name="x"), lambda x: x), - (UPat(UOps.CONST, False, name="x"), lambda x: x), - (UPat(UOps.ALU, BinaryOps.MAX, name="x"), lambda x: x), + (UPat(UOps.CONST, arg=0, name="x"), lambda x: x), + (UPat(UOps.CONST, arg=False, name="x"), lambda x: x), + (UPat(UOps.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x), ]) c1 = UOp(UOps.CONST, dtypes.float, arg=0.0) c2 = UOp(UOps.CONST, dtypes.bool, arg=False) @@ -47,7 +47,7 @@ class TestPatternMatcher(unittest.TestCase): def test_filter_arg(self): matcher = PatternMatcher([ - (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)], name="x"), + (UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(UOps.CONST, name="c"), UPat(UOps.CONST, arg=2)], name="x"), lambda x,c: x if c.arg in {1, -1} else None) ]) y1 = UOp(UOps.CONST, dtypes.int, arg=1) diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index a2fb6a3c5b..9be9125535 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -3,7 +3,7 @@ from typing import Tuple, List from tinygrad.dtype import dtypes, DType from tinygrad.ops import UOp -TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64} +TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64) def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp): """replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio""" diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 738d7a18e7..2e042eed62 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element -from tinygrad.ops import NOp, UPat, PatternMatcher, graph_rewrite +from tinygrad.ops import UPat, PatternMatcher, graph_rewrite from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, all_same, partition from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES if TYPE_CHECKING: from tinygrad.renderer import Renderer @@ -83,9 +83,9 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp): float4_folding = PatternMatcher([ (UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded), - (UPat({UOps.BARRIER, UOps.SINK}, src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded), + (UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded), (UPat(UOps.VECTORIZE, src=UPat(UOps.REDUCE), name="vec"), vectorize_reduce), - (UPat(UOps.VECTORIZE, src=UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}), name="vec"), vectorize_alu), + (UPat(UOps.VECTORIZE, src=UPat((UOps.ALU, UOps.CAST, UOps.BITCAST)), name="vec"), vectorize_alu), ]) # ***** mod ***** @@ -222,145 +222,147 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce): # this is symbolic 2.0 constant_folder = PatternMatcher([ # bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly - (UPat(UOps.ALU, BinaryOps.ADD, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)), - (UPat(UOps.ALU, BinaryOps.MUL, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)), + (UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)), + (UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)), # VECTORIZE/GEP - (NOp(UOps.GEP, src=(NOp(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]), - *[(NOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(NOp(UOps.GEP, dtypes.float, - src=(NOp.var('x', dtype=dtypes.float.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in ([2, 4, 8, 16] + ([256] if AMX else []))], - *[(NOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(NOp(UOps.GEP, dtypes.half, - src=(NOp.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]], + (UPat(UOps.GEP, src=(UPat(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]), + *[(UPat(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UPat(UOps.GEP, dtypes.float, + src=(UPat.var('x', dtype=dtypes.float.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in ([2, 4, 8, 16] + ([256] if AMX else []))], + *[(UPat(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UPat(UOps.GEP, dtypes.half, + src=(UPat.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]], # tensor core with a 0 input is acc - *[(NOp(UOps.WMMA, src=(NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var(), NOp.var('acc'))), + *[(UPat(UOps.WMMA, src=(UPat(UOps.VECTORIZE, src=tuple(UPat.const(None, 0.0) for _ in range(i))), UPat.var(), UPat.var('acc'))), lambda acc: acc) for i in [2, 4, 8]], - *[(NOp(UOps.WMMA, src=(NOp.var(), NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var('acc'))), + *[(UPat(UOps.WMMA, src=(UPat.var(), UPat(UOps.VECTORIZE, src=tuple(UPat.const(None, 0.0) for _ in range(i))), UPat.var('acc'))), lambda acc: acc) for i in [2, 4, 8]], # tensor core cleanups - *[(NOp(UOps.REDUCE, src=(NOp(UOps.EXPAND, src=tuple(NOp(UOps.GEP, dtypes.float, src=(NOp.var('x'),), arg=i) for i in range(j)), name="expand"),) + *[(UPat(UOps.REDUCE, src=(UPat(UOps.EXPAND, src=tuple(UPat(UOps.GEP, dtypes.float, src=(UPat.var('x'),), arg=i) for i in range(j)), name="expand"),) ,name="reduce", allow_any_len=True), reduce_before_expand) for j in ([2,4,8] + ([16,256] if AMX else []))], - (NOp.var("add") + NOp(UOps.WMMA, name="wmma"), + (UPat.var("add") + UPat(UOps.WMMA, name="wmma"), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), # threefry - (NOp(UOps.ALU, dtype=dtypes.uint64, src=(NOp.var("x"), NOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32), + (UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32), # extra arange loop folding because we don't fold adds. TODO: fold adds - (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") + - NOp.var("idx2") + NOp.var("idx3")) - .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), - (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") + - NOp.var("idx2")) - .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), + (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") + + UPat.var("idx2") + UPat.var("idx3")).lt(UPat.cvar("compval")) + .where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), + (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") + + UPat.var("idx2")).lt(UPat.cvar("compval")) + .where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), # arange loop folding (reduce) - (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng")) - .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), + (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng")) + .lt(UPat.cvar("compval")) + .where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), # arange loop folding (unrolled) - (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng")) - .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)) + NOp.var("extra"),), + (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng")) + .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), # unrolled arange div folding - (NOp.var('divs') + NOp.cvar('c'), fold_unrolled_divs), + (UPat.var('divs') + UPat.cvar('c'), fold_unrolled_divs), # indexing (with a multiply offset)! - (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()* - NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),), + (UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).cast()* + UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('add')+UPat.var('mul')*UPat(UOps.RANGE, name="rng")), name="ld"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse), - (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()* - NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.RANGE, name="rng")), name="ld"),), + (UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).cast()* + UPat(UOps.LOAD, src=(UPat.var("buf"), UPat(UOps.RANGE, name="rng")), name="ld"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)), - (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).where( - NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"), NOp.const(None, 0.0)),), + (UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).where( + UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('add')+UPat.var('mul')*UPat(UOps.RANGE, name="rng")), name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse), # max folding - (NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None), + (UPat.max(UPat.var('x'), UPat.var('y')), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None), # GEP/CAST const rules - (NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)), + (UPat(UOps.GEP, src=(UPat.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)), (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const_like(c.arg)), # a conditional with the same results either way is a noop, also fold const conditionals - (NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val), - (NOp.cvar('gate').where(NOp.var('c0'), NOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1), + (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), + (UPat.cvar('gate').where(UPat.var('c0'), UPat.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1), # ** constant folding ** (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))), # ** self folding ** # cast NOOP (NOTE: it's str to deal with PtrDType) - (NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None), - (NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x), # a REDUCE without ranges is a NOOP - (NOp.var('x') + 0, lambda x: x), # x+0 -> x - (NOp.var('x') * 1, lambda x: x), # x*1 -> x - (NOp.var('x') // NOp.var('x'), lambda x: x.const_like(1)), # x//x -> 1 - (NOp.var('x') // 1, lambda x: x), # x//1 -> x - (NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x - (NOp.var('x') / NOp.var('x'), lambda x: x.const_like(1)), # x/x -> 1 - ((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x - (NOp.var('x', dtype=dtypes.bool) & NOp.cvar('c'), lambda x,c: x if c.arg else c), - (NOp.var('x', dtype=dtypes.bool) | NOp.cvar('c'), lambda x,c: c if c.arg else x), + (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None), + (UPat(UOps.REDUCE, src=(UPat.var('x'),)), lambda x: x), # a REDUCE without ranges is a NOOP + (UPat.var('x') + 0, lambda x: x), # x+0 -> x + (UPat.var('x') * 1, lambda x: x), # x*1 -> x + (UPat.var('x') // UPat.var('x'), lambda x: x.const_like(1)), # x//x -> 1 + (UPat.var('x') // 1, lambda x: x), # x//1 -> x + (UPat.var('x') // -1, lambda x: -x), # x//-1 -> -x + (UPat.var('x') / UPat.var('x'), lambda x: x.const_like(1)), # x/x -> 1 + ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x + (UPat.var('x', dtype=dtypes.bool) & UPat.cvar('c'), lambda x,c: x if c.arg else c), + (UPat.var('x', dtype=dtypes.bool) | UPat.cvar('c'), lambda x,c: c if c.arg else x), # ** zero folding ** # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. # NOTE: this can be wrong for loaded NaN - (NOp.var('x') * 0, lambda x: x.const_like(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), + (UPat.var('x') * 0, lambda x: x.const_like(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), # min==max -> CONST (slow!) - (UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), + (UPat((UOps.ALU, UOps.DEFINE_VAR), name='x'), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # ** load/store folding ** - (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)), + (UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.load(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)), # ** two stage add/mul folding ** - ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x+(c1+c2)), - ((NOp.var("x") * NOp.cvar("c1")) * NOp.cvar("c2"), lambda x,c1,c2: x*(c1*c2)), + ((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x+(c1+c2)), + ((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)), # *** rules from symbolic *** # ** lt ** # c0*x 0 and c1.arg > 0 else None), # c0*x x2.vmax and x2.vmin >= 0 else None), # generic lt folding - (NOp.var('x').lt(NOp.cvar('c')), + (UPat.var('x').lt(UPat.cvar('c')), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None), # ** div ** # # div folding - (NOp.var('x') // NOp.cvar('c'), lambda x,c: + (UPat.var('x') // UPat.cvar('c'), lambda x,c: newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None), # ** mod ** # mod folding - (NOp.var('x') % NOp.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), + (UPat.var('x') % UPat.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), # mul mod - ((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None), + ((UPat.cvar('c0')*UPat.var('x')) % UPat.cvar('c1'), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None), # ** combine terms ** - (NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x), # (x%c)+(x//c)*c = x - (NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) - (NOp.var("x") + NOp.var("x") * NOp.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) - (NOp.var("x") + NOp.var("x"), lambda x: x*2), # (x+x)-> x*2 - ((NOp.var("x") // NOp.cvar("c0")) // NOp.cvar("c1"), lambda x,c0,c1: x//(c0*c1)), # (x//c0)//c1 -> x//(c0*c1) - ((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3) - (-1 * (NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y - ((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0 + (UPat.var('x')%UPat.cvar('c')+(UPat.var('x')//UPat.cvar('c'))*UPat.cvar('c'), lambda x,c: x), # (x%c)+(x//c)*c = x + (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) + (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) + (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2 + ((UPat.var("x") // UPat.cvar("c0")) // UPat.cvar("c1"), lambda x,c0,c1: x//(c0*c1)), # (x//c0)//c1 -> x//(c0*c1) + ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3) + (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y + ((UPat.cvar("c0") + UPat.var("x")).lt(UPat.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0 # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue - ((NOp.var("x") + NOp.var("y")) * NOp.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None), + ((UPat.var("x") + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None), # x!=0 -> (bool)x - (NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)), + (UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool)), # bitwise noops - ((NOp.var("x") & NOp.var("x")), lambda x: x), - ((NOp.var("x") | NOp.var("x")), lambda x: x), + ((UPat.var("x") & UPat.var("x")), lambda x: x), + ((UPat.var("x") | UPat.var("x")), lambda x: x), # TODO: can do the invert of this (flip alt/load) when we fix double ops - (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("gate").where(NOp.var("alt"), NOp.load(NOp.var("buf"), NOp.var("idx")))), + (UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat.var("buf"), UPat.var("idx")))), lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)), # fold gated LOAD/STORE - (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)), - (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")), + (UPat.load(UPat.var("buf"), UPat.var("idx"), UPat.var("var"), UPat.const(dtypes.bool, True)), + lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)), + (UPat.load(UPat.var("buf"), UPat.var("idx"), UPat.var("var"), UPat.const(dtypes.bool, True), UPat.var("barrier")), lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)), - (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False)), lambda var: var), - (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var), - (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("val"), NOp.const(dtypes.bool, True)), + (UPat.load(UPat.var(), UPat.var(), UPat.var("var"), UPat.const(dtypes.bool, False)), lambda var: var), + (UPat.load(UPat.var(), UPat.var(), UPat.var("var"), UPat.const(dtypes.bool, False), UPat.var()), lambda var: var), + (UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("val"), UPat.const(dtypes.bool, True)), lambda buf,idx,val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda - (NOp.store(NOp.var(), NOp.var(), NOp.var(), NOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)), + (UPat.store(UPat.var(), UPat.var(), UPat.var(), UPat.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)), # remove NOOPs from SINK - (NOp(UOps.SINK, name="root"), + (UPat(UOps.SINK, name="root"), lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None), # ** move add consts to end (NOTE: this is still happening before constant folding) ** - (UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None), - (UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]), + (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None), + (UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]), lambda x,c1,y: (x+y)+c1), ]) @@ -446,22 +448,22 @@ def create_gate(root:UOp) -> Optional[UOp]: expander = PatternMatcher([ # create gate MUST BE BEFORE expander - (NOp(UOps.STORE, name="root"), create_gate), + (UPat(UOps.STORE, name="root"), create_gate), # do expansion - (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, - UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand), - (NOp(UOps.CONTRACT, name="con"), do_contract), + (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, + UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand), + (UPat(UOps.CONTRACT, name="con"), do_contract), # remove EXPANDs from SINK - (NOp(UOps.SINK, name="root"), + (UPat(UOps.SINK, name="root"), lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None), # BARRIERs aren't actually expanded - (NOp(UOps.BARRIER, src=(NOp(UOps.EXPAND, name="ex"),)), + (UPat(UOps.BARRIER, src=(UPat(UOps.EXPAND, name="ex"),)), lambda ex: UOp(UOps.EXPAND, dtypes.void, (UOp(UOps.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)), # empty EXPAND is NOOP - (NOp(UOps.EXPAND, src=(NOp.var('x'),), arg=()), lambda x: x), + (UPat(UOps.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x), # EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU - (NOp(UOps.EXPAND, name="ex", src=tuple(NOp.var('x').gep(i)+NOp.var('y').gep(i) for i in range(256 if AMX else 8))), + (UPat(UOps.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))), lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)), ]) @@ -474,16 +476,16 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]: return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg) reducer = PatternMatcher([ - (NOp(UOps.REDUCE, name="root"), do_reduce), + (UPat(UOps.REDUCE, name="root"), do_reduce), # no ALU on vectorized dtypes - (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu), + (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name="alu"), no_vectorized_alu), # delete_redundant_gates (after expand, is this still needed?) - (NOp(UOps.STORE, name="root"), delete_redundant_gates), + (UPat(UOps.STORE, name="root"), delete_redundant_gates), # late fixup of unfoldable image loads (UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), ]) -no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE}, name="x"), +no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int32, x.src, x.arg) \ if x.dtype.scalar() == dtypes.pyint else None)]) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 4e298ab851..ec22eb562c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -155,7 +155,7 @@ reduceop_fusor = PatternMatcher([ # push a SWIZZLE down to STORE, through a reduce (ONLY reshapes) (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.SWIZZLE, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes) - (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_swizzle_down_through_elementwise), + (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE), name="root"), push_swizzle_down_through_elementwise), (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index a96e4085dd..66f317d39a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, Sequence, DefaultDict +from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib from enum import auto from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType from tinygrad.helpers import pretty_print, prod, getenv, all_same, HashEnum from tinygrad.shape.symbolic import Variable, sint @@ -335,7 +335,7 @@ class UOps(HashEnum): ENDIF = auto() BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST} - +COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR} END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)} @dataclass(frozen=True, eq=False) @@ -367,9 +367,6 @@ class UOp(MathTrait): def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg - def commutative(self) -> bool: - return (self.op is UOps.ALU and \ - self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}) # *** uop syntactic sugar @property def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1 @@ -381,10 +378,19 @@ class UOp(MathTrait): return ret.arg def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs) def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st) + def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b) def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) - def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i) - def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b) + def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar(), (self,), i) + @classmethod + def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore + @classmethod + def store(cls, *src:UOp): return cls(UOps.STORE, dtypes.void, src) + def alu(self, arg, *src:UOp): + out_dtype = (self, *src)[-1].dtype + if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None: + out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool + return type(self)(UOps.ALU, out_dtype, (self,)+src, arg) @classmethod @functools.lru_cache(None) def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b) @@ -395,15 +401,6 @@ class UOp(MathTrait): if dtype is not None and dtype != (sdtype := dtype.scalar()): return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count))) return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore - def alu(self, arg, *src:UOp): - out_dtype = (self, *src)[-1].dtype - if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None: - out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool - return type(self)(UOps.ALU, out_dtype, (self,)+src, arg) - @classmethod - def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore - @classmethod - def store(cls, *src:UOp): return cls(UOps.STORE, dtypes.void, src) @functools.cached_property def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}} @property # parents with self @@ -608,39 +605,13 @@ def get_location() -> Tuple[str, int]: @functools.lru_cache(None) def lines(fn) -> List[str]: return open(fn).readlines() -@dataclass(frozen=True, eq=False, repr=False) # reuse repr from UOp -class NOp(UOp): - name: Optional[str] = None - # NOTE: this is fine because None dtype in NOp means any dtype is valid. - dtype: Optional[DType] = None # type: ignore - src: Tuple[NOp, ...] = tuple() - allow_any_len: bool = False - location: Tuple[str, int] = field(default_factory=get_location) - - @staticmethod - @functools.lru_cache(None) - def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name) - @staticmethod - @functools.lru_cache(None) - def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name) - - # this is needed so NOp has a different cache - @classmethod - @functools.lru_cache(None) - def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b) - - @functools.cached_property - def upat(self:NOp) -> UPat: - return UPat(name=self.name, dtype=self.dtype, location=self.location) if self.op is UOps.NOOP else \ - UPat(self.op, self.arg, (list if self.commutative() else tuple)([src.upat for src in self.src]) or None, self.name, - self.dtype, self.allow_any_len, location=self.location) - -class UPat: - def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, - name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None, +class UPat(MathTrait): + def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None, + src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None, + name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None): - self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,)) - self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,)) + self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op + self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype self.arg, self.name = arg, name self.in_src = src self.src: Any = None @@ -660,6 +631,31 @@ class UPat: upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0]) self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1) + @staticmethod + @functools.lru_cache(None) + def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name) + @staticmethod + @functools.lru_cache(None) + def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(UOps.CONST, dtype=dtype, name=name) + @staticmethod + @functools.lru_cache(None) + def const(dtype:Optional[DType], b:ConstType|Variable): return UPat(UOps.CONST, dtype=dtype, arg=b) + + # copied from UOp + def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) + def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) + def gep(self, i:int): return type(self)(UOps.GEP, None, (self,), i) + @classmethod + def load(cls, *src:UPat, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore + @classmethod + def store(cls, *src:UPat): return cls(UOps.STORE, dtypes.void, src) + + def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b) + def alu(self, arg, *src:UPat): + asrc = (self,)+src + return type(self)(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, + list(asrc) if arg in COMMUTATIVE else asrc, arg) + def printable(self:UPat) -> str: try: return lines(self.location[0])[self.location[1]-1].strip() @@ -687,8 +683,8 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: return res class PatternMatcher: - def __init__(self, patterns:Sequence[Tuple[Union[UPat, NOp], Callable]]): - self.patterns = [(p.upat if isinstance(p, NOp) else p, fxn) for p,fxn in patterns] + def __init__(self, patterns:List[Tuple[UPat, Callable]]): + self.patterns = patterns self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list) # uop is required, arg is optional for p,fxn in self.patterns: @@ -710,7 +706,7 @@ class PatternMatcher: TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 0) match_stats:Dict[UPat, List[Union[int, float]]] = dict() class TrackedPattenMatcher(PatternMatcher): - def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]): + def __init__(self, patterns:List[Tuple[UPat, Callable]]): super().__init__(patterns) for p,_ in self.patterns: if p not in match_stats: match_stats[p] = [0,0,0.0,0.0] diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index eede5da45d..d330922ee0 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -34,20 +34,21 @@ asm_for_op: Dict[Op, Callable] = { supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] shiftable_consts = set([2**i for i in range(64)]) ptx_matcher = PatternMatcher([ - (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]), + (UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]), src=[UPat(UOps.CONST, name="const"), UPat(name="mul")]), lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None), - (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]), + (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]), src=[UPat(UOps.CONST, name="const"), UPat(name="div")]), lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None), - (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)), - (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"), + (UPat(UOps.ALU, arg=BinaryOps.CMPNE, src=(UPat(dtype=dtypes.bool),UPat()), name="root"), + lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)), + (UPat(UOps.ALU, arg=BinaryOps.CMPLT, src=(UPat(name="x", dtype=dtypes.bool),UPat(name="y")), name="root"), lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x, UOp.const(dtypes.bool, True)), BinaryOps.CMPNE), y), BinaryOps.MUL)), - (UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"), + (UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(name="non_muls"), UPat(UOps.ALU, arg=BinaryOps.MUL, name="muls")], name="root"), lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)), - *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"), + *[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"), lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half))) for op in asm_for_op.keys() if op not in supports_half], (UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX), @@ -63,17 +64,17 @@ ptx_matcher = PatternMatcher([ (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))), lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)), # ptr_ar (load/store) - (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), - UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))), + (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), + UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))), lambda root, alu, const: UOp(root.op, root.dtype, (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64), const*root.src[0].dtype.itemsize)+root.src[2:])), - (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), + (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat(UOps.CONST, name="const"))), lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64), UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])), - (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), + (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat(name="alu"))), # no const here lambda root, alu: UOp(root.op, root.dtype, (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),