Revert "s/UPat/Pat (#7506)" [pr] (#7517)

* Revert "s/UPat/Pat (#7506)"

This reverts commit 400011a8c1.

* fix
This commit is contained in:
chenyu
2024-11-03 16:33:02 -05:00
committed by GitHub
parent e641bbc859
commit 7758f7211b
12 changed files with 347 additions and 347 deletions

View File

@@ -2,7 +2,7 @@ import unittest, pickle, types
import numpy as np
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import PatternMatcher, Pat, UOp
from tinygrad.ops import PatternMatcher, UPat, UOp
class TestPickle(unittest.TestCase):
def test_pickle_code_object(self):
@@ -12,7 +12,7 @@ class TestPickle(unittest.TestCase):
self.assertEqual(fxn(2), 4)
def test_pickle_pattern_matcher(self):
pm = PatternMatcher([(Pat.cvar('x'), lambda x: x*2)])
pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
sink = UOp.const(dtypes.int, 2)
tt = pm.rewrite(sink)
pm_str = pickle.dumps(pm)

View File

@@ -3,7 +3,7 @@ import unittest, time
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo
from tinygrad.ops import Pat, PatternMatcher
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym
@@ -11,10 +11,10 @@ from tinygrad.codegen.linearize import linearize_uop
from tinygrad.shape.shapetracker import ShapeTracker, View
simple_pm = PatternMatcher([
(Pat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(Pat.cvar('x') + Pat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(Pat.cvar('x') * Pat.cvar('y') * Pat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((Pat.var('x') + Pat.cvar('c1')) + Pat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
])
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))

View File

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import Ops, UOp, Pat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -441,13 +441,13 @@ class TestIndexingOrdering(unittest.TestCase):
stores = [st for st in uops if st.op is Ops.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
class TestPatHelpers(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py")
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location Pat files created in test/*?
test_upat = Pat(Ops.CONST, dtypes.bool)
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
if __name__ == '__main__':

View File

@@ -1,7 +1,7 @@
from typing import Dict, List, Optional
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, Pat, \
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
@@ -25,7 +25,7 @@ class TestViz(unittest.TestCase):
def test_viz_simple(self):
pm = PatternMatcher([
(Pat.var("x")*1, lambda x:x),
(UPat.var("x")*1, lambda x:x),
])
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a*1, pm)
@@ -34,8 +34,8 @@ class TestViz(unittest.TestCase):
def test_rewrite_twice(self):
pm = PatternMatcher([
(Pat.var("x")+Pat.var("x"), lambda x:x*2),
(Pat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
])
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a+a, pm)
@@ -51,14 +51,14 @@ class TestViz(unittest.TestCase):
ctx[x] = None
return UOp.store(*x.src, x)
pm = PatternMatcher([
(Pat(Ops.LOAD, name="x"), store_load),
(UPat(Ops.LOAD, name="x"), store_load),
])
uops = helper_test_viz(a+b, pm, {})
self.assertEqual(len(uops), 2)
self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
def test_track_rewrites(self):
simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)])
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
@track_rewrites(named=True)
def do_rewrite(x:UOp): return graph_rewrite(x, simple)
ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
@@ -74,7 +74,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(m.upats), 0)
def test_track_rewrites_with_exception(self):
simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)])
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
@track_rewrites()
def do_rewrite(x:UOp):
x = graph_rewrite(x, simple) # NOTE: viz tracks this

View File

@@ -1,11 +1,11 @@
import unittest, itertools
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
from tinygrad.ops import PatternMatcher, Pat
from tinygrad.ops import PatternMatcher, UPat
class TestPatternMatcher(unittest.TestCase):
def test_simple_match(self):
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.int, arg=1)
self.assertEqual(matcher.rewrite(c1), c1)
@@ -16,7 +16,7 @@ class TestPatternMatcher(unittest.TestCase):
#print(x,y,z)
if y is not None: return a+y
matcher = PatternMatcher([
(Pat.var("a")+Pat.any(Pat.var("x"), Pat.var("y"), Pat.var("z")), test),
(UPat.var("a")+UPat.any(UPat.var("x"), UPat.var("y"), UPat.var("z")), test),
])
v1 = UOp.variable("a", 0, 10)
v2 = UOp.variable("b", 0, 10)
@@ -31,7 +31,7 @@ class TestPatternMatcher(unittest.TestCase):
match_cnt += 1
assert len(x.src) == 0
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)])
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
# second rewrite shouldn't match anything
c1 = matcher.rewrite(c1)
@@ -43,7 +43,7 @@ class TestPatternMatcher(unittest.TestCase):
ctx.append(True)
assert len(x.src) == 0
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)])
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
# second rewrite shouldn't match anything
ctx = []
@@ -52,14 +52,14 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(len(ctx), 1)
def test_uop(self):
matcher = PatternMatcher([(Pat(Ops.CONST, name="x"), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_uop_set(self):
matcher = PatternMatcher([(Pat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
matcher = PatternMatcher([(UPat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.bool, arg=False)
c2 = UOp(Ops.CAST, dtypes.int, (c1,))
c3 = UOp(Ops.CONST, dtypes.float, arg=1.0)
@@ -70,9 +70,9 @@ class TestPatternMatcher(unittest.TestCase):
def test_arg(self):
matcher = PatternMatcher([
(Pat(Ops.CONST, arg=0, name="x"), lambda x: x),
(Pat(Ops.CONST, arg=False, name="x"), lambda x: x),
(Pat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
(UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
(UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
(UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
])
c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
@@ -87,7 +87,7 @@ class TestPatternMatcher(unittest.TestCase):
def test_filter_arg(self):
matcher = PatternMatcher([
(Pat(Ops.ALU, arg=BinaryOps.MUL, src=[Pat(Ops.CONST, name="c"), Pat(Ops.CONST, arg=2)], name="x"),
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
lambda x,c: x if c.arg in {1, -1} else None)
])
y1 = UOp(Ops.CONST, dtypes.int, arg=1)
@@ -105,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c5), c5)
def test_dup_name(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST, name="y"), Pat(Ops.CONST, name="y"))), lambda x, y: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
@@ -114,14 +114,14 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c2), c1)
def test_dtype(self):
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_dtype_set(self):
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
c3 = UOp(Ops.CONST, dtypes.float16, arg=1.0)
@@ -132,7 +132,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_src_one(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST), Pat(Ops.CONST))), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -140,7 +140,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c2), None)
# that CONST/ALU -> ALU/CONST rewrite is now instant
"""
matcher = PatternMatcher([(Pat(UOps.ALU, name="x", src=(Pat(UOps.CONST), Pat(UOps.ALU))), lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), None)
@@ -149,7 +149,7 @@ class TestPatternMatcher(unittest.TestCase):
"""
def test_src_permutations(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=[Pat(Ops.CONST), Pat(Ops.ALU)]), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -162,7 +162,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c6), None)
def test_src_repeat(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=Pat(Ops.CONST)), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -171,7 +171,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_allow_len(self):
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
@@ -188,16 +188,16 @@ class TestPatternMatcher(unittest.TestCase):
u1 = (c1 + c2) + c1
u2 = (c2 + c1) + c1
matcher = PatternMatcher([
(Pat(Ops.ALU, src=[Pat(Ops.ALU, src=[Pat(name='a'), Pat(name='b')]), Pat(name='b')]), lambda a,b: b)
(UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
])
self.assertIsNotNone(matcher.rewrite(u1))
self.assertIsNotNone(matcher.rewrite(u2))
def _assert_eq_upat(self, a:Pat, b:Pat):
def _assert_eq_upat(self, a:UPat, b:UPat):
assert (sorted(map(str,a.op)) if a.op else [] == (sorted(map(str,b.op)) if b.op else []))
assert (sorted(a.dtype) if a.dtype else [] == (sorted(b.dtype) if b.dtype else []))
assert (a.name, type(a.src)) == (b.name, type(b.src))
def simple_src(u:Pat):
def simple_src(u:UPat):
if u.src is None: return []
if isinstance(u.src, itertools.repeat): return next(u.src[0])
return u.src[0]

View File

@@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import variable_to_uop
from tinygrad.dtype import dtypes
from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, Pat, sint, identity_element
from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten
@@ -109,7 +109,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
def lower_load_store(ctx: IndexContext, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
# TODO: check has_valid in Pat, not here
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not Ops.CONST or valid.arg is not True
buf = x.src[0]
if x.op is Ops.LOAD:
@@ -127,10 +127,10 @@ def lower_load_store(ctx: IndexContext, x: UOp):
return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
pm_lowerer = PatternMatcher([
(Pat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(Pat(Ops.VALID, src=(Pat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
(Pat((Ops.LOAD, Ops.STORE), src=(Pat(), Pat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
])
def do_reduce(ctx:List[int], root:UOp):
@@ -141,7 +141,7 @@ def do_reduce(ctx:List[int], root:UOp):
just_reduce = PatternMatcher([
# do reduce
(Pat(Ops.REDUCE, name="root"), do_reduce),
(UPat(Ops.REDUCE, name="root"), do_reduce),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:

View File

@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict,
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, Pat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -77,10 +77,10 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
buf_idx_pat = Pat(Ops.INDEX, src=(Pat.var("buf"),), allow_any_len=True)
buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
float4_folding = PatternMatcher([
(Pat(Ops.VECTORIZE, src=Pat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(Pat((Ops.BARRIER, Ops.SINK), src=Pat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
])
# ***** image load valid simplification *****
@@ -124,24 +124,24 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: List[Tuple[Pat, Callable]] = [(Pat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(Pat.var("d"),), arg=op), f) for op,f in \
pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
if BinaryOps.AND in ops:
pat += [(Pat(Ops.ALU, arg=BinaryOps.MOD, src=(Pat.var('base'), Pat.cvar("const"))),
pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
pat += [
(Pat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[Pat.cvar("const"), Pat.var("mul")]), lambda mul, const:
(UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
(Pat(Ops.ALU, arg=BinaryOps.IDIV, src=(Pat.var("div"), Pat.cvar("const"))), lambda div, const:
(UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
if UnaryOps.NEG in ops:
pat += [(Pat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
if BinaryOps.SUB in ops: pat += [(Pat.var('x')+Pat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]
if TernaryOps.MULACC in ops:
pat += [(Pat.var('a')*Pat.var('b')+Pat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))]
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))]
return PatternMatcher(pat)
# ***** threefry *****
@@ -231,79 +231,79 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
acc_pat, rng_pat = Pat(Ops.DEFINE_ACC, name="acc"), Pat(Ops.RANGE, name="rng")
rng_aug = Pat.any(rng_pat, Pat.var("add")+rng_pat, Pat.var("mul")*rng_pat, Pat.var("add")+Pat.var("mul")*rng_pat)
acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
index_load = Pat.var("buf").index(rng_aug).load(name="ld")
index_load = UPat.var("buf").index(rng_aug).load(name="ld")
arange_augrng = Pat.any(rng_aug, rng_aug+Pat.var("idx2"), rng_aug+Pat.var("idx2")+Pat.var("idx3"), Pat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = arange_augrng.lt(Pat.cvar("compval")).ne(Pat(Ops.CONST, name="ne", arg=True)).where(Pat.cvar("multconst"), Pat.const(None, 0))
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
# this is symbolic 2.0
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
(Pat(Ops.ASSIGN, src=(Pat.var('x'), Pat.var('x'))), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
(Pat(Ops.ASSIGN, src=(Pat(Ops.DEFINE_GLOBAL), Pat.var("x"))), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
# VECTORIZE/CONST, VECTORIZE/GEP
(Pat(Ops.VECTORIZE, src=Pat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(Pat(Ops.VECTORIZE, src=Pat(Ops.GEP, src=(Pat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# reorder ALU/VECTORIZE
(Pat(Ops.ALU, src=(Pat(Ops.VECTORIZE, src=Pat(name='x')), Pat(Ops.VECTORIZE, src=Pat(name='y'))), name='alu'),
(UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
# VECTORIZE of a single element is just that element
(Pat(Ops.VECTORIZE, src=(Pat(name='x'),)), lambda x: x),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# VECTORIZE void is SINK
(Pat(Ops.VECTORIZE, dtype=dtypes.void, src=Pat(Ops.BARRIER, name='b')), lambda b: b),
(Pat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(Pat(Ops.GEP, src=(Pat(Ops.GEP, name='g2'),), name='g1'),
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
(Pat(Ops.GEP, src=(Pat(Ops.VECTORIZE, name="vec"),), name="gep"),
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(Pat(Ops.GEP, src=(Pat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(Pat(Ops.GEP, src=(Pat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(Pat(Ops.GEP, src=(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
(UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
# push some GEPs through WMMAs
(Pat(Ops.GEP, src=(Pat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
# tensor core with a 0 input is acc
(Pat(Ops.WMMA, src=(Pat.const(None, 0.0), Pat.var(), Pat.var("acc"))), lambda acc: acc),
(Pat(Ops.WMMA, src=(Pat.var(), Pat.const(None, 0.0), Pat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
# tensor core cleanups
(Pat.var("add") + Pat(Ops.WMMA, name="wmma"),
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
(Pat(Ops.ALU, dtype=dtypes.uint64, src=(Pat.var("x"), Pat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
(UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
# arange loop folding
(acc_pat.assign(Pat.any(arange_m, arange_m+Pat.var("extra"))+acc_pat), loop_collapse),
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
# indexing, with cast or where
(acc_pat.assign(Pat.var("idx").eq(Pat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(Pat.var("idx").eq(Pat(Ops.RANGE, name="rng")).where(index_load, Pat.const(None, 0.0))+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
# parentless reduce
(acc_pat.assign(Pat(Ops.ALU, src=[acc_pat, Pat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
(acc_pat.assign(Pat(Ops.ALU, src=[acc_pat, Pat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
# ** self folding **
(Pat(Ops.DEFINE_ACC, src=(Pat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(Pat(Ops.ASSIGN, src=(Pat.cvar(),Pat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
(UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
# x!=0 -> (bool)x
(Pat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
(UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
# ** load/store folding **
(Pat.store(Pat(Ops.INDEX, name="index"), Pat.load(Pat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
(Pat.store(Pat(Ops.INDEX, name="index"), Pat.var("gate").where(Pat.var("alt"), Pat.load(Pat(Ops.INDEX, name="index")))),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)),
# fold gated LOAD/STORE
(Pat().index(Pat(), Pat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
(Pat().index(Pat(), Pat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
(Pat(Ops.LOAD, src=(Pat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
(Pat(Ops.STORE, src=(Pat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
(UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
(UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
(UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
# remove NOOPs from SINK
(Pat(Ops.SINK, name="root"),
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
# remove EXPANDs from SINK/BARRIER
(Pat(Ops.BARRIER, src=(Pat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
(Pat(Ops.SINK, name="root"),
(UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.EXPAND} else (x,) for x in root.src)), root.arg)
if any(x.op in {Ops.SINK, Ops.EXPAND} for x in root.src) else None),
])
@@ -400,21 +400,21 @@ def create_gate(root:UOp) -> Optional[UOp]:
expander = PatternMatcher([
# double expand
(Pat(Ops.EXPAND, name="outer", src=(Pat(Ops.EXPAND, name="inner"),)),
(UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)),
lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand),
(Pat(Ops.CONTRACT, name="con"), do_contract),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC
(Pat(Ops.VECTORIZE, src=Pat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
# BARRIERs aren't actually expanded
(Pat(Ops.BARRIER, src=(Pat(Ops.EXPAND, name="ex"),)),
(UPat(Ops.BARRIER, src=(UPat(Ops.EXPAND, name="ex"),)),
lambda ex: UOp(Ops.EXPAND, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
# empty EXPAND is NOOP
(Pat(Ops.EXPAND, src=(Pat.var('x'),), arg=()), lambda x: x),
(UPat(Ops.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
# EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU
(Pat(Ops.EXPAND, name="ex", src=tuple(Pat.var('x').gep(i)+Pat.var('y').gep(i) for i in range(256 if AMX else 8))),
(UPat(Ops.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
lambda ex,x,y: UOp(Ops.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
@@ -433,10 +433,10 @@ def no_vectorized_acc(acc:UOp):
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
(Pat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(Pat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
(Pat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
])
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]:
@@ -446,14 +446,14 @@ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optio
load_store_indexing = PatternMatcher([
# late fixup of unfoldable image loads
(Pat(Ops.LOAD, src=(Pat.var("buf"), Pat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
(UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
# simplify valid
(Pat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
(UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
# image load valid idx simplification
(Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var("start_idx"), Pat.var("valid"))), simplify_valid_load),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# delete_redundant_gates (after expand)
(Pat(Ops.STORE, src=(Pat.any(stidx:=Pat.var("buf").index(Pat.var("idx"), Pat.var("store_gate")), stidx.cast().named("cast")),
Pat.var("val"))), delete_redundant_gates),
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val"))), delete_redundant_gates),
])
def idx_load_store(x:UOp):
@@ -466,9 +466,9 @@ def idx_load_store(x:UOp):
migrate_indexing = PatternMatcher([
# use indexing for LOAD/STORE
(Pat((Ops.LOAD, Ops.STORE), src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
# create gate MUST BE BEFORE expander
(Pat(Ops.STORE, name="root"), create_gate),
(UPat(Ops.STORE, name="root"), create_gate),
])
def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp:
@@ -478,13 +478,13 @@ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp
pm_render = PatternMatcher([
# for rendering, we use explicit VECTORIZE
(Pat(Ops.CONST, name='c'),
(UPat(Ops.CONST, name='c'),
lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(Pat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(Pat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(Pat(Ops.VECTORIZE, src=(Pat(name='x'),)), lambda x: x),
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# move masks of loads/stores
(Pat((Ops.LOAD, Ops.STORE), src=(Pat.any(masked_index:=Pat(Ops.INDEX, src=(Pat(name="buf"), Pat(name="idx"), Pat(name="mask"))),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
])

View File

@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, Pat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -138,27 +138,27 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
merge_views = PatternMatcher([(Pat(Ops.VIEW, src=(Pat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# view before ALU
(Pat(Ops.VIEW, src=(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"),
(UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
])
# push VIEW to stores
view_right = merge_views+PatternMatcher([
# ASSIGN can override st
(Pat(Ops.STORE, src=(Pat.var("b"), Pat.var("st"), Pat(Ops.ASSIGN, name="a"))),
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
# VIEW on a reduce creates a new VIEW
(Pat(Ops.VIEW, src=(Pat(Ops.REDUCE_AXIS, src=Pat.var("rsrc"), name="r"),), name="view"), view_r),
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# push a VIEW down to STORE, through a reduce (ONLY reshapes)
(Pat(Ops.REDUCE_AXIS, src=(Pat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
(Pat(Ops.REDUCE_AXIS, src=(Pat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# ** ScheduleItem context builder
@@ -181,26 +181,26 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
  """Record `x` in ctx.bufs and return a DEFINE_GLOBAL UOp that refers to it by position.

  The returned UOp keeps x's dtype, has no sources, and carries as its arg the
  index at which `x` was appended to ctx.bufs.
  """
  # the slot x is about to occupy is the current length of the list
  buf_slot = len(ctx.bufs)
  ctx.bufs.append(x)
  return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), buf_slot)
append_bufs = PatternMatcher([(Pat(Ops.BUFFER, name="x"), _append_buf)])
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
  """Return `x` with its op replaced by Ops.LOAD.

  As a side effect, `b` is appended to ctx.assign_preloads when it is a member
  of ctx.assigned.
  """
  is_assigned = b in ctx.assigned
  if is_assigned:
    ctx.assign_preloads.append(b)
  return x.replace(op=Ops.LOAD)
to_si = PatternMatcher([
(Pat(Ops.VIEW, name="x"), _append_st_vars),
(Pat(Ops.PRELOAD, src=(Pat.var("b"), Pat()), name="x"), _append_preload),
(Pat(Ops.CONTIGUOUS, src=(Pat.var("x"),)), lambda ctx,x: x),
(Pat(Ops.SINK, src=(Pat.store(Pat(), Pat(), Pat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x),
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
(UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x),
])
# ** fusion
lazy = PatternMatcher([
(Pat.load(b:=Pat.var("b"), Pat(), Pat.store(b, Pat(), Pat.var("v"))), lambda ctx,b,v: v),
(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda ctx,b,v: v),
])
multioutput = PatternMatcher([(Pat.load(Pat.var("b"), Pat()), lambda ctx,b: ctx.get(b)),])
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
# fuse and fold store -> loads
@@ -234,14 +234,14 @@ def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
ctx[b] = store
return UOp(Ops.LOAD, load.dtype, (b, load.st_arg.to_uop()))
def PatLoadStore(to_store=Pat()): return Pat.load(b:=Pat.var("b"), Pat(), Pat.store(b, Pat(), to_store, name="store"), name="load")
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
do_realize = PatternMatcher([
# always realize meta ops
(PatLoadStore(Pat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize),
(Pat((Ops.COPY, Ops.BUFFER_VIEW), src=(Pat.var("u"), Pat.any(PatLoadStore(), PatLoadStore().view(name="v"))), name="root"),
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize),
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"),
lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),)
])
break_sched = PatternMatcher([(PatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:

View File

@@ -205,7 +205,7 @@ def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else ls
def ssimplify(uop):
  """Return uop.ssimplify() when `uop` is a UOp; any other value is returned unchanged."""
  if not isinstance(uop, UOp): return uop
  return uop.ssimplify()
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int:
  """Return uop.sym_infer(var_vals) when `uop` is a UOp; any other value is returned as-is."""
  if isinstance(uop, UOp): return uop.sym_infer(var_vals)
  return uop
# used for UOp and Pat
# used for UOp and UPat
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
def dfs(x:Any, cache:dict):
for s in srcfn(x) or []:
@@ -514,7 +514,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
def get_location() -> Tuple[str, int]:
frm = sys._getframe(1)
# find the real frame in the file that has the Pat, TODO: is there a better way to do this?
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py",
"lowerer.py", "cstyle.py"}:
frm = frm.f_back
@@ -523,75 +523,75 @@ def get_location() -> Tuple[str, int]:
def lines(fn) -> List[str]:
  """Return the contents of the text file `fn` as a list of lines, as produced by readlines().

  The file is decoded as UTF-8 regardless of the process locale, and bytes that
  are not valid UTF-8 are replaced with U+FFFD instead of raising, so reading
  never fails on decoding. Errors from open() itself (e.g. FileNotFoundError)
  propagate to the caller.
  """
  # explicit encoding: a bare open() uses the locale's preferred encoding, which
  # can raise UnicodeDecodeError or mis-decode a UTF-8 file on non-UTF-8 locales
  with open(fn, encoding="utf-8", errors="replace") as f: return f.readlines()
class Pat(MathTrait):
class UPat(MathTrait):
__slots__ = ["op", "dtype", "arg", "name", "src"]
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
src:Optional[Union[Tuple[Pat, ...], List[Pat], Pat]]=None, arg:Any=None,
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None,
custom_early_reject:Optional[Set[Tuple[Ops, Any]]]=None):
self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else op
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
self.src: Any = None
assert self.name != "ctx", "Pat can't be named ctx"
assert self.name != "ctx", "UPat can't be named ctx"
# try all permutations if it's a list
if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
# only one if it's a tuple
elif isinstance(src, tuple): self.src = [src]
# repeat if it's a Pat
elif isinstance(src, Pat): self.src = [itertools.repeat(src)]
# repeat if it's a UPat
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
self.allowed_len: int = -1 if allow_any_len or isinstance(src, Pat) or src is None else len(src)
self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
self.location = location or get_location()
if custom_early_reject is not None: self.early_reject = custom_early_reject
else:
upat_match = [src] if isinstance(src, Pat) else ([] if src is None else self.src[0])
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
def named(self, name:str): return Pat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject)
def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject)
@staticmethod
def any(*src): return PatAny(src=src)
def any(*src): return UPatAny(src=src)
@staticmethod
@functools.lru_cache(None)
def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return Pat(dtype=dtype, name=name)
def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
@staticmethod
@functools.lru_cache(None)
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
return Pat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
@staticmethod
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return Pat(Ops.CONST, dtype=dtype, arg=b)
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
# copied from UOp
def index(self, idx:Pat, valid:Optional[Pat]=None): return Pat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st=None, **kwargs): return Pat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
def cast(self, dtype=None): return Pat(Ops.CAST, dtype, (self,))
def bitcast(self, dtype=None): return Pat(Ops.BITCAST, dtype, (self,))
def gep(self, i:int): return Pat(Ops.GEP, None, (self,), (i,))
def load(self, *src:Pat, **kwargs): return Pat(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:Pat, **kwargs): return Pat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:Pat): return Pat(Ops.ASSIGN, self.dtype, (self,x))
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))
def const_like(self, b:ConstLike): return Pat.const(self.dtype, cast(ConstType, b))
def alu(self, arg, *src:Pat):
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
def alu(self, arg, *src:UPat):
asrc = (self,)+src
return Pat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
return UPat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
def printable(self:Pat) -> str:
def printable(self:UPat) -> str:
try: return lines(self.location[0])[self.location[1]-1].strip()
except FileNotFoundError: return "<missing>"
def __repr__(self):
def rep(x):
form = "Pat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
def match(self:Pat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
if (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
(self.arg is not None and self.arg != uop.arg) or \
@@ -607,8 +607,8 @@ class Pat(MathTrait):
res.extend(stores)
return res
class PatAny(Pat):
def match(self:Pat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
class UPatAny(UPat):
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
ret = []
for x in self.src[0]:
if (match:=x.match(uop, store.copy())): ret.extend(match)
@@ -624,10 +624,10 @@ def deconstruct_function(fxn:Callable) -> Tuple:
return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
class PatternMatcher:
def __init__(self, patterns:List[Tuple[Pat, Callable]]):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
self.patterns = patterns
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
self.pdict: Dict[Tuple[Ops, Any], List[Tuple[Pat, Callable, Set, bool]]] = {}
self.pdict: Dict[Tuple[Ops, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
# uop is required, arg is optional
for p,fxn in self.patterns:
assert p.op is not None
@@ -652,12 +652,12 @@ class PatternMatcher:
# *** tracking pattern matcher ***
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
match_stats:Dict[Pat, List[Union[int, float]]] = dict()
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
@dataclass(frozen=True)
class TrackedRewriteContext:
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
matches: List[Tuple[UOp, Optional[UOp], Optional[Pat], float]] = field(default_factory=list) # all matches of sparents
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
@@ -676,7 +676,7 @@ def track_rewrites(named=False):
return _decorator
class TrackedPatternMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[Pat, Callable]]):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
super().__init__(patterns)
for p,_ in self.patterns:
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
@@ -745,84 +745,84 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
# this is the matcher for the final rendered UOps
# matcher functions return True or False (or None to not match)
spec = PatternMatcher([
(Pat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
(Pat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
(Pat(Ops.DEFINE_ACC, src=(Pat.var("c"),), name="x", allow_any_len=True),
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
(UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
(Pat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(Pat(Ops.RANGE, src=(Pat(name="x"), Pat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
(Pat(Ops.SPECIAL, src=()), lambda: True),
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
(UPat(Ops.SPECIAL, src=()), lambda: True),
# TODO: confirm the args of both of these are shapetrackers
(Pat(Ops.VIEW, src=()), lambda: True),
(Pat(Ops.VIEW, src=(Pat(),)), lambda: True),
(UPat(Ops.VIEW, src=()), lambda: True),
(UPat(Ops.VIEW, src=(UPat(),)), lambda: True),
(Pat(Ops.VALID, dtypes.bool, (Pat(Ops.VIEW),)), lambda: True),
(Pat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
(UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
# early LOAD has a <buf, shapetracker, store?>
(Pat(Ops.LOAD, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW))), lambda: True),
(Pat(Ops.LOAD, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW), Pat(Ops.STORE))), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
# early STORE has a <buf, shapetracker, val>
(Pat(Ops.STORE, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW), Pat())), lambda: True),
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
# **** new style load/store ****
# INDEX is used in new style load/store
(Pat(Ops.INDEX, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
# LOAD takes a <bufidx, alt?, gate?, barrier?>
(Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)),)), lambda: True),
(Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)), Pat((Ops.IF, Ops.BARRIER)))), lambda: True),
(Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(name="alt"), Pat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
# STORE takes a <bufidx, val, gate?>
(Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat())), lambda: True),
(Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(), Pat(dtype=dtypes.bool))), lambda: True),
(Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(), Pat(Ops.IF))), lambda: True),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(Pat(Ops.ALU, name="w", src=(Pat(dtype=dtypes.bool), Pat(name="x"), Pat(name="y")), arg=TernaryOps.WHERE),
(UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
lambda w,x,y: w.dtype == x.dtype == y.dtype),
(Pat(Ops.ALU, dtype=dtypes.bool, src=(Pat(name="x"), Pat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
(Pat(Ops.ALU, dtype=dtypes.bool, src=(Pat(name="x"), Pat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
# and SHL/SHR, the shift distance is an int
(Pat(Ops.ALU, src=(Pat(name="x"), Pat(name="y")), name="alu", arg=BinaryOps.SHL),
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(Pat(Ops.ALU, src=(Pat(name="x"), Pat(name="y")), name="alu", arg=BinaryOps.SHR),
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(Pat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(Pat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
(UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
(Pat(Ops.ASSIGN, src=(Pat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), Pat())), lambda: True),
(Pat(Ops.ENDRANGE, dtype=dtypes.void, src=(Pat(Ops.RANGE),)), lambda: True),
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
# all WMMA has 3 args, <x, w, acc>
(Pat(Ops.WMMA, src=(Pat(), Pat(), Pat())), lambda: True),
(Pat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(Pat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# if has a <gate, barrier?>
(Pat(Ops.IF, dtype=dtypes.void, src=(Pat(),)), lambda: True),
(Pat(Ops.IF, dtype=dtypes.void, src=(Pat(), Pat(Ops.BARRIER))), lambda: True),
(Pat(Ops.ENDIF, dtype=dtypes.void, src=(Pat(Ops.IF),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
(Pat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
(Pat(Ops.GEP, src=(Pat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(Pat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(Pat((Ops.BITCAST, Ops.CAST), src=(Pat(),), name="x"), lambda x: x.arg is None),
(Pat(Ops.BARRIER, dtypes.void, src=Pat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
(UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
# NOTE: for testing, we let sinks be anything
#(Pat(UOps.SINK, src=Pat(UOps.STORE)), lambda: True),
(Pat(Ops.SINK, dtypes.void), lambda: True),
(Pat(Ops.NOOP), lambda: True),
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat(Ops.NOOP), lambda: True),
# PTX LOAD/STORE
(Pat((Ops.LOAD, Ops.STORE), src=(Pat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
(Pat(Ops.BARRIER, dtypes.void, src=Pat(Ops.STORE, src=(Pat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
])
def type_verify(uops:List[UOp]):
@@ -1013,125 +1013,125 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
symbolic_simple = PatternMatcher([
# ** self folding **
(Pat.var("x") + 0, lambda x: x), # x+0 -> x
(Pat.var("x") * 1, lambda x: x), # x*1 -> x
(Pat.var("x") // Pat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
(Pat.var("x") // 1, lambda x: x), # x//1 -> x
(Pat.var("x") // -1, lambda x: -x), # x//-1 -> -x
(Pat.var("x") / Pat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
((Pat.var("x") * Pat.var("x2")) / Pat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
((Pat.var() % Pat.var("y")).named("base") % Pat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
(Pat.var("x")%Pat.cvar("c")+(Pat.var("x")//Pat.cvar("c"))*Pat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
(Pat.var("x", dtype=dtypes.bool) & Pat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(Pat.var("x", dtype=dtypes.bool) | Pat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(Pat.var("x").maximum(Pat.var("x")), lambda x: x),
((Pat.var("x") & Pat.var("x")), lambda x: x),
((Pat.var("x") | Pat.var("x")), lambda x: x),
(Pat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(UPat.var("x").maximum(UPat.var("x")), lambda x: x),
((UPat.var("x") & UPat.var("x")), lambda x: x),
((UPat.var("x") | UPat.var("x")), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
# ** zero folding **
(Pat.var("x") < Pat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
(Pat.var("x", dtype=dtypes.ints) != Pat.var("x", dtype=dtypes.ints),
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
# x*0 -> 0 or 0*x -> 0
# if x is nan or inf it should render the nan value.
# NOTE: this can be wrong for loaded NaN
(Pat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# ** constant folding **
(Pat(Ops.ALU, name="root", src=Pat((Ops.VCONST, Ops.CONST))),
(UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(Pat.var('x', dtype=dtypes.bool) * Pat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
(Pat.var('x', dtype=dtypes.bool) + Pat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
(Pat.var('x', dtype=dtypes.bool).maximum(Pat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
# *** cast ***
(Pat(Ops.CAST, name="root", src=Pat.cvar("c")), lambda root, c: root.const_like(c.arg)),
(Pat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
(UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
(UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
])
symbolic = symbolic_simple+PatternMatcher([
# ** COMMUTATIVE flipping **
*[(Pat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
*[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
# group like
((Pat.var("x") + Pat.var("y")) + Pat.var("x") * Pat.cvar("c"), lambda x,y,c: (x+x*c)+y),
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
# ** combine terms **
(Pat.var("x") * Pat.cvar("c0") + Pat.var("x") * Pat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(Pat.var("x") + Pat.var("x") * Pat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
(Pat.var("x") + Pat.var("x"), lambda x: x*2), # (x+x)-> x*2
((Pat.var("x") / Pat.var("x2")) / Pat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (Pat.var("x") + Pat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
# a conditional with the same results either way is a noop, also fold const conditionals
(Pat.var().where(Pat.var("val"), Pat.var("val")), lambda val: val),
(Pat.cvar("gate", vec=False).where(Pat.var("c0"), Pat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# ALU min==max -> CONST (slow!)
(Pat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
(Pat.maximum(Pat.var("x"), Pat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# TODO: why does this rule break beautiful_mnist?
#((Pat.var("x")+Pat.var("z")).maximum(Pat.var("y")+Pat.var("z")), lambda x,y,z: x.maximum(y) + z),
((Pat.var("x")*Pat.cvar("c1")).maximum(Pat.var("x")*Pat.cvar("c2")), max_var_const),
#((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
# ** two stage ALU folding **
((Pat.var("x") + Pat.cvar("c1")) + Pat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
((Pat.var("x") * Pat.cvar("c1")) * Pat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
((Pat.var("x") & Pat.cvar("c1")) & Pat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
((Pat.var("x") | Pat.cvar("c1")) | Pat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
((Pat.cvar("c0") + Pat.var("x")) < Pat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
((Pat.var("x") // Pat.cvar("c1")) // Pat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
# ** lt **
# c0*x<c1 for positive int c0,c1
((Pat.cvar("c0", vec=False)*Pat.var("x", dtype=dtypes.ints)).lt(Pat.cvar("c1", vec=False)),
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
# c0*x<c1 for negative int c0 and non-positive c1
((Pat.cvar("c0", vec=False)*Pat.var("x", dtype=dtypes.ints)).lt(Pat.cvar("c1", vec=False)),
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# x//c0<c1 for positive int c0
((Pat.var("x", dtype=dtypes.ints)//Pat.cvar("c0", vec=False)).lt(Pat.cvar("c1", vec=False)),
((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
# mul add lt
(((Pat.cvar("c0", vec=False)*Pat.var("x"))+Pat.var("x2")).lt(Pat.cvar("c1", vec=False)),
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
(Pat(Ops.ALU, arg=BinaryOps.ADD, src=(Pat.var("x"), Pat.cvar("c1"))) + Pat.var("y"), lambda x,c1,y: (x+y)+c1),
(Pat(Ops.ALU, arg=BinaryOps.MUL, src=(Pat.var("x"), Pat.cvar("c1"))) * Pat.var("y"), lambda x,c1,y: (x*y)*c1),
(UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
# *** rules from symbolic ***
# unrolled arange div folding
(Pat(Ops.ALU, name="divs", src=[Pat(), Pat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
(UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
# generic lt folding
(Pat.var("x", dtypes.sints).lt(Pat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
(UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
# canonicalize a simplex with positive coefficients > 0
# not x < 1 -> X > 0
(Pat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
(UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
# ** div **
# # div folding
(Pat.var("x", dtypes.sints) // Pat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
(UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
# ** mod **
# mod folding
(Pat.var("x") % Pat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
])
symbolic_flat = symbolic+PatternMatcher([
# ** combine terms (opinionated) **
(-1 * (Pat.var("x") + Pat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((Pat.var("x", dtypes.ints) + Pat.var("y")) * Pat.cvar("c"), lambda x,y,c: x*c+y*c),
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])
_substitute = PatternMatcher([(Pat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
# for debug
syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
renderer = PatternMatcher([
(Pat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
(Pat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
(Pat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(Pat(Ops.BIND, src=Pat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=TernaryOps.MULACC),
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=TernaryOps.WHERE),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
])
# *** what was symbolic.py ***

View File

@@ -2,66 +2,66 @@ from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, Pat, cast_float_to_bf16
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
base_rewrite = PatternMatcher([
(Pat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
(Pat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
(Pat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(Pat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
(Pat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
(UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
# r method accesses
(Pat(Ops.RANGE, name="x"),
(UPat(Ops.RANGE, name="x"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
(Pat(Ops.VECTORIZE, name="x"),
(UPat(Ops.VECTORIZE, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
(Pat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"),
(Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
(Pat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
(Pat(Ops.BARRIER), lambda ctx: ctx.barrier),
(Pat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
(Pat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
# const
(Pat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
(Pat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
(Pat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
(Pat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(Pat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(Pat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
(Pat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(Pat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
(UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
(UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
(UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
# consts are rendered to larger type and casted
(Pat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
(Pat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
(Pat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
(UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
(UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
# default const render
(Pat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
(Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var('idx'))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"),
(Pat(Ops.LOAD, src=(Pat.var('bidx'), Pat.var("var"), Pat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(Pat(Ops.LOAD, src=(Pat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(Pat(Ops.STORE, src=(Pat.var('bidx'), Pat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
(Pat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
(UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
*([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
(Pat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])
extra_pm = PatternMatcher([
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(Pat(Ops.BITCAST, name="x"),
(UPat(Ops.BITCAST, name="x"),
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
# gate any stores that aren't gated with ifs
(Pat(Ops.STORE, dtype=dtypes.void, src=(Pat(), Pat(), Pat(dtype=dtypes.bool)), name="store"),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(Pat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
(UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])
def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -214,13 +214,13 @@ class OpenCLRenderer(CStyleLanguage):
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
string_rewrite = PatternMatcher([
(Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
# load/store image (OpenCL)
(Pat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))), Pat.var("var"), Pat.var("gate"))),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
(Pat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))),)),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
(Pat(Ops.STORE, src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))), Pat.var("var", dtypes.float.vec(4))), allow_any_len=True),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
]) + base_rewrite
@@ -234,8 +234,8 @@ class IntelRenderer(OpenCLRenderer):
st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
string_rewrite = PatternMatcher([
(Pat(Ops.CAST, dtype=dtypes.bfloat16, src=(Pat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
(Pat(Ops.CAST, dtype=dtypes.float, src=(Pat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
]) + OpenCLRenderer.string_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
@@ -272,13 +272,13 @@ class MetalRenderer(CStyleLanguage):
# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
*[(Pat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
]) + extra_pm
string_rewrite = PatternMatcher([
(Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
@@ -387,20 +387,20 @@ class AMDRenderer(CStyleLanguage):
type_map = {dtypes.bfloat16: "hip_bfloat16"}
extra_matcher = PatternMatcher([
# cast bfloat16 alus to float
(Pat(Ops.ALU, arg=TernaryOps.WHERE, src=(Pat.var("b"), Pat.var("x", dtype=dtypes.bfloat16), Pat.var("y", dtype=dtypes.bfloat16))),
(UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(Pat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
(UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
(Pat(Ops.ALU, dtypes.bool, name="alu", src=(Pat.var("x", dtype=dtypes.bfloat16), Pat.var("y", dtype=dtypes.bfloat16))),
(UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
# add float intermediate casting for bfloat16
(Pat(Ops.CAST, name="x", src=Pat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(Pat(Ops.CAST, dtypes.bfloat16, Pat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
# bfloat16 casting
(Pat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
(Pat(Ops.CAST, dtype=dtypes.float, src=Pat.var("x", dtype=dtypes.bfloat16)),
(UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
(UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
(Pat(Ops.CAST, dtype=dtypes.bfloat16, src=Pat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
def render_vector_prefix(self, dtype:DType) -> str:
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())

View File

@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, Pat
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, UPat
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -35,23 +35,23 @@ asm_for_op: Dict[Op, Callable] = {
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
(Pat.var('x', dtype=dtypes.bool).ne(Pat.var('y')), lambda x,y: x^y),
(Pat.var('x', dtype=dtypes.bool).lt(Pat.var('y')), lambda x,y: (x^True)&y),
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
# upcast to float32 all the ops that don't support half
*[(Pat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
for op in asm_for_op.keys() if op not in supports_half],
# load/store bool -> uint8
(Pat(Ops.LOAD, dtypes.bool, src=(Pat(dtype=dtypes.int64),), name="x", allow_any_len=True),
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
(Pat(Ops.STORE, src=(Pat(dtype=dtypes.int64), Pat(dtype=dtypes.bool)), name="x", allow_any_len=True),
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
# load/store use pointer arithmetic, and the cast does nothing
(Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(Pat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
# ptx shr and shl instructions require y to be uint
(Pat.var("x") << Pat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
(Pat.var("x") >> Pat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
])
class PTXRenderer(Renderer):

View File

@@ -25,7 +25,7 @@ class GraphRewriteMetadata:
kernel_name: Optional[str]
"""The kernel calling graph_rewrite"""
upats: List[Tuple[Tuple[str, int], str, float]]
"""List of all the applied Pats"""
"""List of all the applied UPats"""
@dataclass
class GraphRewriteDetails(GraphRewriteMetadata):