remove nop, use upat [run_process_replay] (#6489)

* remove nop, use upat [run_process_replay]

* mypy passes

* no wonder nothing worked

* fixes
This commit is contained in:
George Hotz
2024-09-12 12:16:19 +08:00
committed by GitHub
parent f12f0857d8
commit 76487a3533
8 changed files with 165 additions and 171 deletions

View File

@@ -5,16 +5,16 @@ from tinygrad import dtypes, Device
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
from tinygrad.ops import NOp, PatternMatcher
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.codegen.lowerer import ast_to_uop
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding
from tinygrad.shape.shapetracker import ShapeTracker, View
simple_pm = PatternMatcher([
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(NOp.cvar('x') + NOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(NOp.cvar('x') * NOp.cvar('y') * NOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
])
def to_uops_list(u:List[UOp]) -> List[UOp]:
  """Bundle the given UOps under one sink, run `full_graph_rewrite` on it, and return what `linearize_uop` produces."""
  sink = UOp.sink(*u)
  rewritten = full_graph_rewrite(sink)
  return linearize_uop(rewritten)

View File

@@ -7,7 +7,6 @@ from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UOps, UOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
from tinygrad.ops import NOp
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -389,10 +388,6 @@ class TestUOpStr(unittest.TestCase):
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
assert_equiv_uops(sink, eval(str(sink)))
def test_nop_str(self):
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")
assert str(eval(str(a))) == str(a)
def test_variable_const(self):
# TODO: this is not possible after VALID.
uop = UOp(UOps.CONST, dtypes.int, (), arg=Variable("a",1,10))

View File

@@ -30,9 +30,9 @@ class TestPatternMatcher(unittest.TestCase):
def test_arg(self):
matcher = PatternMatcher([
(UPat(UOps.CONST, 0, name="x"), lambda x: x),
(UPat(UOps.CONST, False, name="x"), lambda x: x),
(UPat(UOps.ALU, BinaryOps.MAX, name="x"), lambda x: x),
(UPat(UOps.CONST, arg=0, name="x"), lambda x: x),
(UPat(UOps.CONST, arg=False, name="x"), lambda x: x),
(UPat(UOps.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
])
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
@@ -47,7 +47,7 @@ class TestPatternMatcher(unittest.TestCase):
def test_filter_arg(self):
matcher = PatternMatcher([
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)], name="x"),
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(UOps.CONST, name="c"), UPat(UOps.CONST, arg=2)], name="x"),
lambda x,c: x if c.arg in {1, -1} else None)
])
y1 = UOp(UOps.CONST, dtypes.int, arg=1)

View File

@@ -3,7 +3,7 @@ from typing import Tuple, List
from tinygrad.dtype import dtypes, DType
from tinygrad.ops import UOp
TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64}
TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
"""replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""

View File

@@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
from tinygrad.ops import NOp, UPat, PatternMatcher, graph_rewrite
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, all_same, partition
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
if TYPE_CHECKING: from tinygrad.renderer import Renderer
@@ -83,9 +83,9 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
float4_folding = PatternMatcher([
(UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
(UPat({UOps.BARRIER, UOps.SINK}, src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
(UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
(UPat(UOps.VECTORIZE, src=UPat(UOps.REDUCE), name="vec"), vectorize_reduce),
(UPat(UOps.VECTORIZE, src=UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}), name="vec"), vectorize_alu),
(UPat(UOps.VECTORIZE, src=UPat((UOps.ALU, UOps.CAST, UOps.BITCAST)), name="vec"), vectorize_alu),
])
# ***** mod *****
@@ -222,145 +222,147 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
# this is symbolic 2.0
constant_folder = PatternMatcher([
# bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat(UOps.ALU, BinaryOps.ADD, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
(UPat(UOps.ALU, BinaryOps.MUL, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
# VECTORIZE/GEP
(NOp(UOps.GEP, src=(NOp(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]),
*[(NOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(NOp(UOps.GEP, dtypes.float,
src=(NOp.var('x', dtype=dtypes.float.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in ([2, 4, 8, 16] + ([256] if AMX else []))],
*[(NOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(NOp(UOps.GEP, dtypes.half,
src=(NOp.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
(UPat(UOps.GEP, src=(UPat(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]),
*[(UPat(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UPat(UOps.GEP, dtypes.float,
src=(UPat.var('x', dtype=dtypes.float.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in ([2, 4, 8, 16] + ([256] if AMX else []))],
*[(UPat(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UPat(UOps.GEP, dtypes.half,
src=(UPat.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
# tensor core with a 0 input is acc
*[(NOp(UOps.WMMA, src=(NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var(), NOp.var('acc'))),
*[(UPat(UOps.WMMA, src=(UPat(UOps.VECTORIZE, src=tuple(UPat.const(None, 0.0) for _ in range(i))), UPat.var(), UPat.var('acc'))),
lambda acc: acc) for i in [2, 4, 8]],
*[(NOp(UOps.WMMA, src=(NOp.var(), NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var('acc'))),
*[(UPat(UOps.WMMA, src=(UPat.var(), UPat(UOps.VECTORIZE, src=tuple(UPat.const(None, 0.0) for _ in range(i))), UPat.var('acc'))),
lambda acc: acc) for i in [2, 4, 8]],
# tensor core cleanups
*[(NOp(UOps.REDUCE, src=(NOp(UOps.EXPAND, src=tuple(NOp(UOps.GEP, dtypes.float, src=(NOp.var('x'),), arg=i) for i in range(j)), name="expand"),)
*[(UPat(UOps.REDUCE, src=(UPat(UOps.EXPAND, src=tuple(UPat(UOps.GEP, dtypes.float, src=(UPat.var('x'),), arg=i) for i in range(j)), name="expand"),)
,name="reduce", allow_any_len=True), reduce_before_expand) for j in ([2,4,8] + ([16,256] if AMX else []))],
(NOp.var("add") + NOp(UOps.WMMA, name="wmma"),
(UPat.var("add") + UPat(UOps.WMMA, name="wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
(NOp(UOps.ALU, dtype=dtypes.uint64, src=(NOp.var("x"), NOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
(UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
# extra arange loop folding because we don't fold adds. TODO: fold adds
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
NOp.var("idx2") + NOp.var("idx3"))
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
NOp.var("idx2"))
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") +
UPat.var("idx2") + UPat.var("idx3")).lt(UPat.cvar("compval"))
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") +
UPat.var("idx2")).lt(UPat.cvar("compval"))
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (reduce)
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
.lt(UPat.cvar("compval"))
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (unrolled)
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)) + NOp.var("extra"),),
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# unrolled arange div folding
(NOp.var('divs') + NOp.cvar('c'), fold_unrolled_divs),
(UPat.var('divs') + UPat.cvar('c'), fold_unrolled_divs),
# indexing (with a multiply offset)!
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
(UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).cast()*
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('add')+UPat.var('mul')*UPat(UOps.RANGE, name="rng")), name="ld"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.RANGE, name="rng")), name="ld"),),
(UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).cast()*
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat(UOps.RANGE, name="rng")), name="ld"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).where(
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"), NOp.const(None, 0.0)),),
(UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).where(
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('add')+UPat.var('mul')*UPat(UOps.RANGE, name="rng")), name="ld"), UPat.const(None, 0.0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
# max folding
(NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
(UPat.max(UPat.var('x'), UPat.var('y')), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# GEP/CAST const rules
(NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const_like(c.arg)),
# a conditional with the same results either way is a noop, also fold const conditionals
(NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val),
(NOp.cvar('gate').where(NOp.var('c0'), NOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar('gate').where(UPat.var('c0'), UPat.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
# ** self folding **
# cast NOOP (NOTE: it's str to deal with PtrDType)
(NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
(NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x), # a REDUCE without ranges is a NOOP
(NOp.var('x') + 0, lambda x: x), # x+0 -> x
(NOp.var('x') * 1, lambda x: x), # x*1 -> x
(NOp.var('x') // NOp.var('x'), lambda x: x.const_like(1)), # x//x -> 1
(NOp.var('x') // 1, lambda x: x), # x//1 -> x
(NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
(NOp.var('x') / NOp.var('x'), lambda x: x.const_like(1)), # x/x -> 1
((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
(NOp.var('x', dtype=dtypes.bool) & NOp.cvar('c'), lambda x,c: x if c.arg else c),
(NOp.var('x', dtype=dtypes.bool) | NOp.cvar('c'), lambda x,c: c if c.arg else x),
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
(UPat(UOps.REDUCE, src=(UPat.var('x'),)), lambda x: x), # a REDUCE without ranges is a NOOP
(UPat.var('x') + 0, lambda x: x), # x+0 -> x
(UPat.var('x') * 1, lambda x: x), # x*1 -> x
(UPat.var('x') // UPat.var('x'), lambda x: x.const_like(1)), # x//x -> 1
(UPat.var('x') // 1, lambda x: x), # x//1 -> x
(UPat.var('x') // -1, lambda x: -x), # x//-1 -> -x
(UPat.var('x') / UPat.var('x'), lambda x: x.const_like(1)), # x/x -> 1
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
(UPat.var('x', dtype=dtypes.bool) & UPat.cvar('c'), lambda x,c: x if c.arg else c),
(UPat.var('x', dtype=dtypes.bool) | UPat.cvar('c'), lambda x,c: c if c.arg else x),
# ** zero folding **
# x*0 -> 0 or 0*x -> 0
# if x is nan or inf it should render the nan value.
# NOTE: this can be wrong for loaded NaN
(NOp.var('x') * 0, lambda x: x.const_like(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
(UPat.var('x') * 0, lambda x: x.const_like(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# min==max -> CONST (slow!)
(UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat((UOps.ALU, UOps.DEFINE_VAR), name='x'), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# ** load/store folding **
(NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.load(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
# ** two stage add/mul folding **
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x+(c1+c2)),
((NOp.var("x") * NOp.cvar("c1")) * NOp.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x+(c1+c2)),
((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
# *** rules from symbolic ***
# ** lt **
# c0*x<c1 for positive int c0,c1
((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
((UPat.cvar('c0')*UPat.var('x')).lt(UPat.cvar('c1')),
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
# c0*x<c1 for negative int c0 and non-positive c1
((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
((UPat.cvar('c0')*UPat.var('x')).lt(UPat.cvar('c1')),
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if dtypes.is_int(x.dtype) and c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# mul add lt
(((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
(((UPat.cvar('c0')*UPat.var('x'))+UPat.var('x2')).lt(UPat.cvar('c1')),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
# generic lt folding
(NOp.var('x').lt(NOp.cvar('c')),
(UPat.var('x').lt(UPat.cvar('c')),
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
# ** div **
# # div folding
(NOp.var('x') // NOp.cvar('c'), lambda x,c:
(UPat.var('x') // UPat.cvar('c'), lambda x,c:
newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None),
# ** mod **
# mod folding
(NOp.var('x') % NOp.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
(UPat.var('x') % UPat.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
# mul mod
((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None),
((UPat.cvar('c0')*UPat.var('x')) % UPat.cvar('c1'), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None),
# ** combine terms **
(NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x), # (x%c)+(x//c)*c = x
(NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(NOp.var("x") + NOp.var("x") * NOp.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
(NOp.var("x") + NOp.var("x"), lambda x: x*2), # (x+x)-> x*2
((NOp.var("x") // NOp.cvar("c0")) // NOp.cvar("c1"), lambda x,c0,c1: x//(c0*c1)), # (x//c0)//c1 -> x//(c0*c1)
((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0
(UPat.var('x')%UPat.cvar('c')+(UPat.var('x')//UPat.cvar('c'))*UPat.cvar('c'), lambda x,c: x), # (x%c)+(x//c)*c = x
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
((UPat.var("x") // UPat.cvar("c0")) // UPat.cvar("c1"), lambda x,c0,c1: x//(c0*c1)), # (x//c0)//c1 -> x//(c0*c1)
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
((UPat.cvar("c0") + UPat.var("x")).lt(UPat.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((NOp.var("x") + NOp.var("y")) * NOp.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None),
((UPat.var("x") + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None),
# x!=0 -> (bool)x
(NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
(UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
# bitwise noops
((NOp.var("x") & NOp.var("x")), lambda x: x),
((NOp.var("x") | NOp.var("x")), lambda x: x),
((UPat.var("x") & UPat.var("x")), lambda x: x),
((UPat.var("x") | UPat.var("x")), lambda x: x),
# TODO: can do the invert of this (flip alt/load) when we fix double ops
(NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("gate").where(NOp.var("alt"), NOp.load(NOp.var("buf"), NOp.var("idx")))),
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat.var("buf"), UPat.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# fold gated LOAD/STORE
(NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
(NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
(UPat.load(UPat.var("buf"), UPat.var("idx"), UPat.var("var"), UPat.const(dtypes.bool, True)),
lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
(UPat.load(UPat.var("buf"), UPat.var("idx"), UPat.var("var"), UPat.const(dtypes.bool, True), UPat.var("barrier")),
lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
(NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False)), lambda var: var),
(NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
(NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("val"), NOp.const(dtypes.bool, True)),
(UPat.load(UPat.var(), UPat.var(), UPat.var("var"), UPat.const(dtypes.bool, False)), lambda var: var),
(UPat.load(UPat.var(), UPat.var(), UPat.var("var"), UPat.const(dtypes.bool, False), UPat.var()), lambda var: var),
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("val"), UPat.const(dtypes.bool, True)),
lambda buf,idx,val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda
(NOp.store(NOp.var(), NOp.var(), NOp.var(), NOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
(UPat.store(UPat.var(), UPat.var(), UPat.var(), UPat.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
# remove NOOPs from SINK
(NOp(UOps.SINK, name="root"),
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
# ** move add consts to end (NOTE: this is still happening before constant folding) **
(UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
(UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]),
lambda x,c1,y: (x+y)+c1),
])
@@ -446,22 +448,22 @@ def create_gate(root:UOp) -> Optional[UOp]:
expander = PatternMatcher([
# create gate MUST BE BEFORE expander
(NOp(UOps.STORE, name="root"), create_gate),
(UPat(UOps.STORE, name="root"), create_gate),
# do expansion
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
(NOp(UOps.CONTRACT, name="con"), do_contract),
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
(UPat(UOps.CONTRACT, name="con"), do_contract),
# remove EXPANDs from SINK
(NOp(UOps.SINK, name="root"),
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
# BARRIERs aren't actually expanded
(NOp(UOps.BARRIER, src=(NOp(UOps.EXPAND, name="ex"),)),
(UPat(UOps.BARRIER, src=(UPat(UOps.EXPAND, name="ex"),)),
lambda ex: UOp(UOps.EXPAND, dtypes.void, (UOp(UOps.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
# empty EXPAND is NOOP
(NOp(UOps.EXPAND, src=(NOp.var('x'),), arg=()), lambda x: x),
(UPat(UOps.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
# EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU
(NOp(UOps.EXPAND, name="ex", src=tuple(NOp.var('x').gep(i)+NOp.var('y').gep(i) for i in range(256 if AMX else 8))),
(UPat(UOps.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
@@ -474,16 +476,16 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
reducer = PatternMatcher([
(NOp(UOps.REDUCE, name="root"), do_reduce),
(UPat(UOps.REDUCE, name="root"), do_reduce),
# no ALU on vectorized dtypes
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name="alu"), no_vectorized_alu),
# delete_redundant_gates (after expand, is this still needed?)
(NOp(UOps.STORE, name="root"), delete_redundant_gates),
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
# late fixup of unfoldable image loads
(UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
])
no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE}, name="x"),
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int32, x.src, x.arg) \
if x.dtype.scalar() == dtypes.pyint else None)])

View File

@@ -155,7 +155,7 @@ reduceop_fusor = PatternMatcher([
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.SWIZZLE, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_swizzle_down_through_elementwise),
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])

View File

@@ -1,9 +1,9 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, Sequence, DefaultDict
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
from enum import auto
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
from tinygrad.helpers import pretty_print, prod, getenv, all_same, HashEnum
from tinygrad.shape.symbolic import Variable, sint
@@ -335,7 +335,7 @@ class UOps(HashEnum):
ENDIF = auto()
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
@dataclass(frozen=True, eq=False)
@@ -367,9 +367,6 @@ class UOp(MathTrait):
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self):
return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg
def commutative(self) -> bool:
return (self.op is UOps.ALU and \
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
# *** uop syntactic sugar
@property
def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
@@ -381,10 +378,19 @@ class UOp(MathTrait):
return ret.arg
def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar(), (self,), i)
@classmethod
def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore
@classmethod
def store(cls, *src:UOp): return cls(UOps.STORE, dtypes.void, src)
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return type(self)(UOps.ALU, out_dtype, (self,)+src, arg)
@classmethod
@functools.lru_cache(None)
def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b)
@@ -395,15 +401,6 @@ class UOp(MathTrait):
if dtype is not None and dtype != (sdtype := dtype.scalar()):
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return type(self)(UOps.ALU, out_dtype, (self,)+src, arg)
@classmethod
def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore
@classmethod
def store(cls, *src:UOp): return cls(UOps.STORE, dtypes.void, src)
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
@property # parents with self
@@ -608,39 +605,13 @@ def get_location() -> Tuple[str, int]:
@functools.lru_cache(None)
def lines(fn) -> List[str]:
  """Return the lines of the file at path `fn`, line endings included.

  Results are memoized per `fn` by the unbounded `lru_cache`, so a file is read at most once
  and later changes on disk are not picked up.
  """
  # read inside a `with` block so the handle is closed right away instead of being left to the GC
  with open(fn) as f: return f.readlines()
# Pattern-description node: a UOp subclass that adds the fields a pattern
# needs (name, allow_any_len, location). The `upat` property converts an NOp
# tree into the corresponding UPat tree.
@dataclass(frozen=True, eq=False, repr=False) # reuse repr from UOp
class NOp(UOp):
name: Optional[str] = None
# NOTE: this is fine because None dtype in NOp means any dtype is valid.
dtype: Optional[DType] = None # type: ignore
src: Tuple[NOp, ...] = tuple()
allow_any_len: bool = False
# filled in by get_location() when the NOp is created; forwarded to UPat as `location`
location: Tuple[str, int] = field(default_factory=get_location)
# var: an NOp with op UOps.NOOP; `upat` below turns NOOP into a UPat with no op constraint
@staticmethod
@functools.lru_cache(None)
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
# cvar: an NOp with op UOps.CONST and the optional name/dtype
@staticmethod
@functools.lru_cache(None)
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
# this is needed so NOp has a different cache
@classmethod
@functools.lru_cache(None)
def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b)
# upat: a NOOP node becomes a UPat carrying only name, dtype and location.
# Any other node carries over op and arg and converts its sources recursively:
# as a list when self.commutative() is true, as a tuple otherwise, and as None
# when there are no sources (an empty list/tuple is falsy, so `or None` applies).
@functools.cached_property
def upat(self:NOp) -> UPat:
return UPat(name=self.name, dtype=self.dtype, location=self.location) if self.op is UOps.NOOP else \
UPat(self.op, self.arg, (list if self.commutative() else tuple)([src.upat for src in self.src]) or None, self.name,
self.dtype, self.allow_any_len, location=self.location)
class UPat:
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None,
class UPat(MathTrait):
def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None,
custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
self.arg, self.name = arg, name
self.in_src = src
self.src: Any = None
@@ -660,6 +631,31 @@ class UPat:
upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0])
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
# Pattern with no op constraint: only the optional name and dtype are set.
# lru_cache(None) returns the same UPat object for a repeated (name, dtype).
@staticmethod
@functools.lru_cache(None)
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name)
# Pattern restricted to op UOps.CONST, with optional name and dtype and no arg.
# lru_cache(None) returns the same UPat object for a repeated (name, dtype).
@staticmethod
@functools.lru_cache(None)
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(UOps.CONST, dtype=dtype, name=name)
# Pattern restricted to op UOps.CONST whose arg is `b`, with the given dtype.
# Memoized on (dtype, b), so both must be hashable.
@staticmethod
@functools.lru_cache(None)
def const(dtype:Optional[DType], b:ConstType|Variable): return UPat(UOps.CONST, dtype=dtype, arg=b)
# copied from UOp
# Pattern for a CAST node whose only source is this pattern; `dtype` defaults to None.
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
# Pattern for a BITCAST node whose only source is this pattern; `dtype` defaults to None.
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
# Pattern for a GEP node over this pattern with arg `i`; the dtype is left as None.
def gep(self, i:int): return type(self)(UOps.GEP, None, (self,), i)
# Pattern for a LOAD node whose sources are the given patterns (as a tuple);
# `dtype` is keyword-only and defaults to None.
@classmethod
def load(cls, *src:UPat, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore
# Pattern for a STORE node with dtype dtypes.void whose sources are the given
# patterns (as a tuple).
@classmethod
def store(cls, *src:UPat): return cls(UOps.STORE, dtypes.void, src)
# CONST pattern with arg `b` that reuses this pattern's own dtype field.
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
# Pattern for an ALU node applying `arg` to the operand patterns (self, *src).
# The dtype is None for CMPLT/CMPNE and otherwise the dtype field of the last
# operand pattern. Sources are passed as a list when `arg` is in COMMUTATIVE
# and as a tuple otherwise.
# NOTE(review): presumably a list src lets the matcher accept the operands in
# any order -- confirm against UPat.__init__ and _match.
def alu(self, arg, *src:UPat):
asrc = (self,)+src
return type(self)(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype,
list(asrc) if arg in COMMUTATIVE else asrc, arg)
def printable(self:UPat) -> str:
try:
return lines(self.location[0])[self.location[1]-1].strip()
@@ -687,8 +683,8 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
return res
class PatternMatcher:
def __init__(self, patterns:Sequence[Tuple[Union[UPat, NOp], Callable]]):
self.patterns = [(p.upat if isinstance(p, NOp) else p, fxn) for p,fxn in patterns]
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
self.patterns = patterns
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list)
# uop is required, arg is optional
for p,fxn in self.patterns:
@@ -710,7 +706,7 @@ class PatternMatcher:
# Match-statistics switch, taken from getenv("TRACK_MATCH_STATS", 0).
# NOTE(review): presumably this reads the TRACK_MATCH_STATS environment
# variable with a default of 0 -- confirm against the getenv helper.
TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 0)
# Per-pattern statistics keyed by UPat; TrackedPattenMatcher.__init__ seeds
# each new pattern's entry with [0,0,0.0,0.0].
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
class TrackedPattenMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
super().__init__(patterns)
for p,_ in self.patterns:
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]

View File

@@ -34,20 +34,21 @@ asm_for_op: Dict[Op, Callable] = {
# Ops exempt from the half rule in ptx_matcher: that rule is generated only
# for asm_for_op keys that are NOT listed here, and rewrites them to compute
# in float32 and cast the result back to half.
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
# The 64 powers of two 2**0 .. 2**63; ptx_matcher only turns a constant
# MUL/IDIV into SHL/SHR when the constant is a member of this set.
shiftable_consts = {1 << i for i in range(64)}
ptx_matcher = PatternMatcher([
(UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
(UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
src=[UPat(UOps.CONST, name="const"), UPat(name="mul")]),
lambda root, mul, const: UOp(UOps.ALU, root.dtype,
(mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
(UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
src=[UPat(UOps.CONST, name="const"), UPat(name="div")]),
lambda root, div, const: UOp(UOps.ALU, root.dtype,
(div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None),
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
(UPat(UOps.ALU, arg=BinaryOps.CMPNE, src=(UPat(dtype=dtypes.bool),UPat()), name="root"),
lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
(UPat(UOps.ALU, arg=BinaryOps.CMPLT, src=(UPat(name="x", dtype=dtypes.bool),UPat(name="y")), name="root"),
lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x, UOp.const(dtypes.bool, True)), BinaryOps.CMPNE), y), BinaryOps.MUL)),
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(name="non_muls"), UPat(UOps.ALU, arg=BinaryOps.MUL, name="muls")], name="root"),
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
for op in asm_for_op.keys() if op not in supports_half],
(UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
@@ -63,17 +64,17 @@ ptx_matcher = PatternMatcher([
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
# ptr_ar (load/store)
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
lambda root, alu, const: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
const*root.src[0].dtype.itemsize)+root.src[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
UPat(UOps.CONST, name="const"))),
lambda root, const: UOp(root.op, root.dtype,
(root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
UPat(name="alu"))), # no const here
lambda root, alu: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),