mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
* Revert "s/UPat/Pat (#7506)"
This reverts commit 400011a8c1.
* fix
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest, pickle, types
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.ops import PatternMatcher, Pat, UOp
|
||||
from tinygrad.ops import PatternMatcher, UPat, UOp
|
||||
|
||||
class TestPickle(unittest.TestCase):
|
||||
def test_pickle_code_object(self):
|
||||
@@ -12,7 +12,7 @@ class TestPickle(unittest.TestCase):
|
||||
self.assertEqual(fxn(2), 4)
|
||||
|
||||
def test_pickle_pattern_matcher(self):
|
||||
pm = PatternMatcher([(Pat.cvar('x'), lambda x: x*2)])
|
||||
pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
|
||||
sink = UOp.const(dtypes.int, 2)
|
||||
tt = pm.rewrite(sink)
|
||||
pm_str = pickle.dumps(pm)
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest, time
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo
|
||||
from tinygrad.ops import Pat, PatternMatcher
|
||||
from tinygrad.ops import UPat, PatternMatcher
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym
|
||||
@@ -11,10 +11,10 @@ from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
(Pat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
(Pat.cvar('x') + Pat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
|
||||
(Pat.cvar('x') * Pat.cvar('y') * Pat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
|
||||
((Pat.var('x') + Pat.cvar('c1')) + Pat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
|
||||
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
(UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
|
||||
(UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
|
||||
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
|
||||
])
|
||||
|
||||
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import Ops, UOp, Pat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.engine.schedule import create_schedule, to_si
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
@@ -441,13 +441,13 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||
stores = [st for st in uops if st.op is Ops.STORE]
|
||||
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
||||
|
||||
class TestPatHelpers(unittest.TestCase):
|
||||
class TestUPatHelpers(unittest.TestCase):
|
||||
def test_location(self):
|
||||
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "uopgraph.py")
|
||||
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
|
||||
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
|
||||
with self.assertRaises(AssertionError): # TODO: location Pat files created in test/*?
|
||||
test_upat = Pat(Ops.CONST, dtypes.bool)
|
||||
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
|
||||
test_upat = UPat(Ops.CONST, dtypes.bool)
|
||||
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Optional
|
||||
import unittest
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, Pat, \
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
|
||||
graph_rewrite, contexts, track_rewrites
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
||||
|
||||
@@ -25,7 +25,7 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
def test_viz_simple(self):
|
||||
pm = PatternMatcher([
|
||||
(Pat.var("x")*1, lambda x:x),
|
||||
(UPat.var("x")*1, lambda x:x),
|
||||
])
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
uops = helper_test_viz(a*1, pm)
|
||||
@@ -34,8 +34,8 @@ class TestViz(unittest.TestCase):
|
||||
|
||||
def test_rewrite_twice(self):
|
||||
pm = PatternMatcher([
|
||||
(Pat.var("x")+Pat.var("x"), lambda x:x*2),
|
||||
(Pat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
|
||||
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
|
||||
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
|
||||
])
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
uops = helper_test_viz(a+a, pm)
|
||||
@@ -51,14 +51,14 @@ class TestViz(unittest.TestCase):
|
||||
ctx[x] = None
|
||||
return UOp.store(*x.src, x)
|
||||
pm = PatternMatcher([
|
||||
(Pat(Ops.LOAD, name="x"), store_load),
|
||||
(UPat(Ops.LOAD, name="x"), store_load),
|
||||
])
|
||||
uops = helper_test_viz(a+b, pm, {})
|
||||
self.assertEqual(len(uops), 2)
|
||||
self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
|
||||
|
||||
def test_track_rewrites(self):
|
||||
simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)])
|
||||
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
@track_rewrites(named=True)
|
||||
def do_rewrite(x:UOp): return graph_rewrite(x, simple)
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
@@ -74,7 +74,7 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(len(m.upats), 0)
|
||||
|
||||
def test_track_rewrites_with_exception(self):
|
||||
simple = PatternMatcher([(Pat.var("x")*1, lambda x:x)])
|
||||
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
@track_rewrites()
|
||||
def do_rewrite(x:UOp):
|
||||
x = graph_rewrite(x, simple) # NOTE: viz tracks this
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import unittest, itertools
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
|
||||
from tinygrad.ops import PatternMatcher, Pat
|
||||
from tinygrad.ops import PatternMatcher, UPat
|
||||
|
||||
class TestPatternMatcher(unittest.TestCase):
|
||||
def test_simple_match(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.int, arg=1)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
@@ -16,7 +16,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
#print(x,y,z)
|
||||
if y is not None: return a+y
|
||||
matcher = PatternMatcher([
|
||||
(Pat.var("a")+Pat.any(Pat.var("x"), Pat.var("y"), Pat.var("z")), test),
|
||||
(UPat.var("a")+UPat.any(UPat.var("x"), UPat.var("y"), UPat.var("z")), test),
|
||||
])
|
||||
v1 = UOp.variable("a", 0, 10)
|
||||
v2 = UOp.variable("b", 0, 10)
|
||||
@@ -31,7 +31,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
match_cnt += 1
|
||||
assert len(x.src) == 0
|
||||
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
# second rewrite shouldn't match anything
|
||||
c1 = matcher.rewrite(c1)
|
||||
@@ -43,7 +43,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
ctx.append(True)
|
||||
assert len(x.src) == 0
|
||||
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
# second rewrite shouldn't match anything
|
||||
ctx = []
|
||||
@@ -52,14 +52,14 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(len(ctx), 1)
|
||||
|
||||
def test_uop(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, name="x"), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_uop_set(self):
|
||||
matcher = PatternMatcher([(Pat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
c2 = UOp(Ops.CAST, dtypes.int, (c1,))
|
||||
c3 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
@@ -70,9 +70,9 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
def test_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(Pat(Ops.CONST, arg=0, name="x"), lambda x: x),
|
||||
(Pat(Ops.CONST, arg=False, name="x"), lambda x: x),
|
||||
(Pat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
|
||||
(UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
|
||||
(UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
|
||||
])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
@@ -87,7 +87,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
def test_filter_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(Pat(Ops.ALU, arg=BinaryOps.MUL, src=[Pat(Ops.CONST, name="c"), Pat(Ops.CONST, arg=2)], name="x"),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
|
||||
lambda x,c: x if c.arg in {1, -1} else None)
|
||||
])
|
||||
y1 = UOp(Ops.CONST, dtypes.int, arg=1)
|
||||
@@ -105,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c5), c5)
|
||||
|
||||
def test_dup_name(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST, name="y"), Pat(Ops.CONST, name="y"))), lambda x, y: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
|
||||
y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
|
||||
@@ -114,14 +114,14 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c2), c1)
|
||||
|
||||
def test_dtype(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_dtype_set(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
|
||||
c3 = UOp(Ops.CONST, dtypes.float16, arg=1.0)
|
||||
@@ -132,7 +132,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_src_one(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST), Pat(Ops.CONST))), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
@@ -140,7 +140,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
# that CONST/ALU -> ALU/CONST rewrite is now instant
|
||||
"""
|
||||
matcher = PatternMatcher([(Pat(UOps.ALU, name="x", src=(Pat(UOps.CONST), Pat(UOps.ALU))), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
|
||||
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), None)
|
||||
@@ -149,7 +149,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_src_permutations(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=[Pat(Ops.CONST), Pat(Ops.ALU)]), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
@@ -162,7 +162,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c6), None)
|
||||
|
||||
def test_src_repeat(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=Pat(Ops.CONST)), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
@@ -171,7 +171,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_allow_len(self):
|
||||
matcher = PatternMatcher([(Pat(Ops.ALU, name="x", src=(Pat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
|
||||
@@ -188,16 +188,16 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
u1 = (c1 + c2) + c1
|
||||
u2 = (c2 + c1) + c1
|
||||
matcher = PatternMatcher([
|
||||
(Pat(Ops.ALU, src=[Pat(Ops.ALU, src=[Pat(name='a'), Pat(name='b')]), Pat(name='b')]), lambda a,b: b)
|
||||
(UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
|
||||
])
|
||||
self.assertIsNotNone(matcher.rewrite(u1))
|
||||
self.assertIsNotNone(matcher.rewrite(u2))
|
||||
|
||||
def _assert_eq_upat(self, a:Pat, b:Pat):
|
||||
def _assert_eq_upat(self, a:UPat, b:UPat):
|
||||
assert (sorted(map(str,a.op)) if a.op else [] == (sorted(map(str,b.op)) if b.op else []))
|
||||
assert (sorted(a.dtype) if a.dtype else [] == (sorted(b.dtype) if b.dtype else []))
|
||||
assert (a.name, type(a.src)) == (b.name, type(b.src))
|
||||
def simple_src(u:Pat):
|
||||
def simple_src(u:UPat):
|
||||
if u.src is None: return []
|
||||
if isinstance(u.src, itertools.repeat): return next(u.src[0])
|
||||
return u.src[0]
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import variable_to_uop
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, Pat, sint, identity_element
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten
|
||||
|
||||
@@ -109,7 +109,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
|
||||
def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
# TODO: check has_valid in Pat, not here
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not Ops.CONST or valid.arg is not True
|
||||
buf = x.src[0]
|
||||
if x.op is Ops.LOAD:
|
||||
@@ -127,10 +127,10 @@ def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
|
||||
|
||||
pm_lowerer = PatternMatcher([
|
||||
(Pat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
|
||||
(Pat(Ops.VALID, src=(Pat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
|
||||
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
|
||||
(Pat((Ops.LOAD, Ops.STORE), src=(Pat(), Pat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
||||
])
|
||||
|
||||
def do_reduce(ctx:List[int], root:UOp):
|
||||
@@ -141,7 +141,7 @@ def do_reduce(ctx:List[int], root:UOp):
|
||||
|
||||
just_reduce = PatternMatcher([
|
||||
# do reduce
|
||||
(Pat(Ops.REDUCE, name="root"), do_reduce),
|
||||
(UPat(Ops.REDUCE, name="root"), do_reduce),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict,
|
||||
import functools, itertools, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, Pat, PatternMatcher, symbolic_flat, symbolic_simple
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
|
||||
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
@@ -77,10 +77,10 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
|
||||
vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
|
||||
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
|
||||
|
||||
buf_idx_pat = Pat(Ops.INDEX, src=(Pat.var("buf"),), allow_any_len=True)
|
||||
buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
|
||||
float4_folding = PatternMatcher([
|
||||
(Pat(Ops.VECTORIZE, src=Pat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
|
||||
(Pat((Ops.BARRIER, Ops.SINK), src=Pat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
|
||||
(UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
|
||||
])
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
@@ -124,24 +124,24 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
|
||||
powers_of_two = {2**i:i for i in range(64)}
|
||||
@functools.lru_cache(None)
|
||||
def get_late_rewrite_patterns(ops, force_transcendental=False):
|
||||
pat: List[Tuple[Pat, Callable]] = [(Pat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(Pat.var("d"),), arg=op), f) for op,f in \
|
||||
pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
|
||||
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
|
||||
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
|
||||
if BinaryOps.AND in ops:
|
||||
pat += [(Pat(Ops.ALU, arg=BinaryOps.MOD, src=(Pat.var('base'), Pat.cvar("const"))),
|
||||
pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
|
||||
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
|
||||
# rewrite MUL/IDIV to SHL+SHR
|
||||
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
|
||||
pat += [
|
||||
(Pat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[Pat.cvar("const"), Pat.var("mul")]), lambda mul, const:
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
|
||||
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
|
||||
(Pat(Ops.ALU, arg=BinaryOps.IDIV, src=(Pat.var("div"), Pat.cvar("const"))), lambda div, const:
|
||||
(UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
|
||||
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
|
||||
if UnaryOps.NEG in ops:
|
||||
pat += [(Pat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
|
||||
if BinaryOps.SUB in ops: pat += [(Pat.var('x')+Pat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]
|
||||
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
|
||||
if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]
|
||||
if TernaryOps.MULACC in ops:
|
||||
pat += [(Pat.var('a')*Pat.var('b')+Pat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))]
|
||||
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
# ***** threefry *****
|
||||
@@ -231,79 +231,79 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
|
||||
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
|
||||
acc_pat, rng_pat = Pat(Ops.DEFINE_ACC, name="acc"), Pat(Ops.RANGE, name="rng")
|
||||
rng_aug = Pat.any(rng_pat, Pat.var("add")+rng_pat, Pat.var("mul")*rng_pat, Pat.var("add")+Pat.var("mul")*rng_pat)
|
||||
acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
|
||||
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
|
||||
|
||||
index_load = Pat.var("buf").index(rng_aug).load(name="ld")
|
||||
index_load = UPat.var("buf").index(rng_aug).load(name="ld")
|
||||
|
||||
arange_augrng = Pat.any(rng_aug, rng_aug+Pat.var("idx2"), rng_aug+Pat.var("idx2")+Pat.var("idx3"), Pat(Ops.VECTORIZE, name="vec", src=rng_aug))
|
||||
arange_m = arange_augrng.lt(Pat.cvar("compval")).ne(Pat(Ops.CONST, name="ne", arg=True)).where(Pat.cvar("multconst"), Pat.const(None, 0))
|
||||
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
|
||||
arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
|
||||
|
||||
# this is symbolic 2.0
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
# self ASSIGN is just self
|
||||
(Pat(Ops.ASSIGN, src=(Pat.var('x'), Pat.var('x'))), lambda x: x),
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
|
||||
# ASSIGN to global is just self
|
||||
(Pat(Ops.ASSIGN, src=(Pat(Ops.DEFINE_GLOBAL), Pat.var("x"))), lambda x: x),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
|
||||
# VECTORIZE/CONST, VECTORIZE/GEP
|
||||
(Pat(Ops.VECTORIZE, src=Pat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
|
||||
(Pat(Ops.VECTORIZE, src=Pat(Ops.GEP, src=(Pat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
|
||||
# reorder ALU/VECTORIZE
|
||||
(Pat(Ops.ALU, src=(Pat(Ops.VECTORIZE, src=Pat(name='x')), Pat(Ops.VECTORIZE, src=Pat(name='y'))), name='alu'),
|
||||
(UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
|
||||
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
|
||||
# VECTORIZE of a single element is just that element
|
||||
(Pat(Ops.VECTORIZE, src=(Pat(name='x'),)), lambda x: x),
|
||||
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
|
||||
# VECTORIZE void is SINK
|
||||
(Pat(Ops.VECTORIZE, dtype=dtypes.void, src=Pat(Ops.BARRIER, name='b')), lambda b: b),
|
||||
(Pat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
|
||||
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
|
||||
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(Pat(Ops.GEP, src=(Pat(Ops.GEP, name='g2'),), name='g1'),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
(Pat(Ops.GEP, src=(Pat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(Pat(Ops.GEP, src=(Pat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(Pat(Ops.GEP, src=(Pat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
# push all GEPs through ALUs (fix arange stuff)
|
||||
(Pat(Ops.GEP, src=(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
(UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
|
||||
# push some GEPs through WMMAs
|
||||
(Pat(Ops.GEP, src=(Pat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
|
||||
# tensor core with a 0 input is acc
|
||||
(Pat(Ops.WMMA, src=(Pat.const(None, 0.0), Pat.var(), Pat.var("acc"))), lambda acc: acc),
|
||||
(Pat(Ops.WMMA, src=(Pat.var(), Pat.const(None, 0.0), Pat.var("acc"))), lambda acc: acc),
|
||||
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
|
||||
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
|
||||
# tensor core cleanups
|
||||
(Pat.var("add") + Pat(Ops.WMMA, name="wmma"),
|
||||
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
# threefry
|
||||
(Pat(Ops.ALU, dtype=dtypes.uint64, src=(Pat.var("x"), Pat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
(UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
# arange loop folding
|
||||
(acc_pat.assign(Pat.any(arange_m, arange_m+Pat.var("extra"))+acc_pat), loop_collapse),
|
||||
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
|
||||
# indexing, with cast or where
|
||||
(acc_pat.assign(Pat.var("idx").eq(Pat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
|
||||
(acc_pat.assign(Pat.var("idx").eq(Pat(Ops.RANGE, name="rng")).where(index_load, Pat.const(None, 0.0))+acc_pat), index_collapse),
|
||||
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
|
||||
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
|
||||
# parentless reduce
|
||||
(acc_pat.assign(Pat(Ops.ALU, src=[acc_pat, Pat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
|
||||
(acc_pat.assign(Pat(Ops.ALU, src=[acc_pat, Pat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
|
||||
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
|
||||
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
|
||||
# ** self folding **
|
||||
(Pat(Ops.DEFINE_ACC, src=(Pat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
|
||||
(Pat(Ops.ASSIGN, src=(Pat.cvar(),Pat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
|
||||
(UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
|
||||
(UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
|
||||
# x!=0 -> (bool)x
|
||||
(Pat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
|
||||
(UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
|
||||
# ** load/store folding **
|
||||
(Pat.store(Pat(Ops.INDEX, name="index"), Pat.load(Pat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
|
||||
(Pat.store(Pat(Ops.INDEX, name="index"), Pat.var("gate").where(Pat.var("alt"), Pat.load(Pat(Ops.INDEX, name="index")))),
|
||||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
|
||||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
|
||||
lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)),
|
||||
# fold gated LOAD/STORE
|
||||
(Pat().index(Pat(), Pat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
|
||||
(Pat().index(Pat(), Pat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
|
||||
(Pat(Ops.LOAD, src=(Pat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
|
||||
(Pat(Ops.STORE, src=(Pat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
|
||||
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
|
||||
(UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
|
||||
(UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
|
||||
(UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
|
||||
# remove NOOPs from SINK
|
||||
(Pat(Ops.SINK, name="root"),
|
||||
(UPat(Ops.SINK, name="root"),
|
||||
lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
|
||||
# remove EXPANDs from SINK/BARRIER
|
||||
(Pat(Ops.BARRIER, src=(Pat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
|
||||
(Pat(Ops.SINK, name="root"),
|
||||
(UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
|
||||
(UPat(Ops.SINK, name="root"),
|
||||
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.EXPAND} else (x,) for x in root.src)), root.arg)
|
||||
if any(x.op in {Ops.SINK, Ops.EXPAND} for x in root.src) else None),
|
||||
])
|
||||
@@ -400,21 +400,21 @@ def create_gate(root:UOp) -> Optional[UOp]:
|
||||
|
||||
expander = PatternMatcher([
|
||||
# double expand
|
||||
(Pat(Ops.EXPAND, name="outer", src=(Pat(Ops.EXPAND, name="inner"),)),
|
||||
(UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)),
|
||||
lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
|
||||
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
|
||||
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand),
|
||||
(Pat(Ops.CONTRACT, name="con"), do_contract),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# vectorize DEFINE_ACC
|
||||
(Pat(Ops.VECTORIZE, src=Pat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
|
||||
# BARRIERs aren't actually expanded
|
||||
(Pat(Ops.BARRIER, src=(Pat(Ops.EXPAND, name="ex"),)),
|
||||
(UPat(Ops.BARRIER, src=(UPat(Ops.EXPAND, name="ex"),)),
|
||||
lambda ex: UOp(Ops.EXPAND, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
|
||||
# empty EXPAND is NOOP
|
||||
(Pat(Ops.EXPAND, src=(Pat.var('x'),), arg=()), lambda x: x),
|
||||
(UPat(Ops.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
|
||||
# EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU
|
||||
(Pat(Ops.EXPAND, name="ex", src=tuple(Pat.var('x').gep(i)+Pat.var('y').gep(i) for i in range(256 if AMX else 8))),
|
||||
(UPat(Ops.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
|
||||
lambda ex,x,y: UOp(Ops.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
|
||||
])
|
||||
|
||||
@@ -433,10 +433,10 @@ def no_vectorized_acc(acc:UOp):
|
||||
|
||||
devectorize = PatternMatcher([
|
||||
# no ALU on vectorized dtypes
|
||||
(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
|
||||
(Pat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
|
||||
(Pat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
|
||||
(Pat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
|
||||
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
|
||||
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
|
||||
(UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
|
||||
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
|
||||
])
|
||||
|
||||
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]:
|
||||
@@ -446,14 +446,14 @@ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optio
|
||||
|
||||
load_store_indexing = PatternMatcher([
|
||||
# late fixup of unfoldable image loads
|
||||
(Pat(Ops.LOAD, src=(Pat.var("buf"), Pat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
|
||||
# simplify valid
|
||||
(Pat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
|
||||
(UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
|
||||
# image load valid idx simplification
|
||||
(Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var("start_idx"), Pat.var("valid"))), simplify_valid_load),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
|
||||
# delete_redundant_gates (after expand)
|
||||
(Pat(Ops.STORE, src=(Pat.any(stidx:=Pat.var("buf").index(Pat.var("idx"), Pat.var("store_gate")), stidx.cast().named("cast")),
|
||||
Pat.var("val"))), delete_redundant_gates),
|
||||
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
|
||||
UPat.var("val"))), delete_redundant_gates),
|
||||
])
|
||||
|
||||
def idx_load_store(x:UOp):
|
||||
@@ -466,9 +466,9 @@ def idx_load_store(x:UOp):
|
||||
|
||||
migrate_indexing = PatternMatcher([
|
||||
# use indexing for LOAD/STORE
|
||||
(Pat((Ops.LOAD, Ops.STORE), src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
|
||||
# create gate MUST BE BEFORE expander
|
||||
(Pat(Ops.STORE, name="root"), create_gate),
|
||||
(UPat(Ops.STORE, name="root"), create_gate),
|
||||
])
|
||||
|
||||
def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp:
|
||||
@@ -478,13 +478,13 @@ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp
|
||||
|
||||
pm_render = PatternMatcher([
|
||||
# for rendering, we use explicit VECTORIZE
|
||||
(Pat(Ops.CONST, name='c'),
|
||||
(UPat(Ops.CONST, name='c'),
|
||||
lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
|
||||
(Pat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(Pat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(Pat(Ops.VECTORIZE, src=(Pat(name='x'),)), lambda x: x),
|
||||
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
|
||||
# move masks of loads/stores
|
||||
(Pat((Ops.LOAD, Ops.STORE), src=(Pat.any(masked_index:=Pat(Ops.INDEX, src=(Pat(name="buf"), Pat(name="idx"), Pat(name="mask"))),
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
|
||||
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
|
||||
])
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, Pat, Variable, graph_rewrite, track_rewrites, sint
|
||||
from tinygrad.ops import BUFOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
|
||||
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -138,27 +138,27 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
|
||||
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
|
||||
|
||||
merge_views = PatternMatcher([(Pat(Ops.VIEW, src=(Pat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
|
||||
merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
|
||||
|
||||
# push VIEW to loads
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view before ALU
|
||||
(Pat(Ops.VIEW, src=(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"),
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFOPS), name="e"),), name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||
])
|
||||
|
||||
# push VIEW to stores
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# ASSIGN can override st
|
||||
(Pat(Ops.STORE, src=(Pat.var("b"), Pat.var("st"), Pat(Ops.ASSIGN, name="a"))),
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
|
||||
lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
|
||||
# VIEW on a reduce creates a new VIEW
|
||||
(Pat(Ops.VIEW, src=(Pat(Ops.REDUCE_AXIS, src=Pat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||
# push a VIEW down to STORE, through a reduce (ONLY reshapes)
|
||||
(Pat(Ops.REDUCE_AXIS, src=(Pat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
# push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||
(Pat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
|
||||
(Pat(Ops.REDUCE_AXIS, src=(Pat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
# ** ScheduleItem context builder
|
||||
@@ -181,26 +181,26 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
|
||||
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
ctx.bufs.append(x)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
|
||||
append_bufs = PatternMatcher([(Pat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
|
||||
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
|
||||
if b in ctx.assigned: ctx.assign_preloads.append(b)
|
||||
return x.replace(op=Ops.LOAD)
|
||||
|
||||
to_si = PatternMatcher([
|
||||
(Pat(Ops.VIEW, name="x"), _append_st_vars),
|
||||
(Pat(Ops.PRELOAD, src=(Pat.var("b"), Pat()), name="x"), _append_preload),
|
||||
(Pat(Ops.CONTIGUOUS, src=(Pat.var("x"),)), lambda ctx,x: x),
|
||||
(Pat(Ops.SINK, src=(Pat.store(Pat(), Pat(), Pat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x),
|
||||
(UPat(Ops.VIEW, name="x"), _append_st_vars),
|
||||
(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
|
||||
(UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x),
|
||||
])
|
||||
|
||||
# ** fusion
|
||||
|
||||
lazy = PatternMatcher([
|
||||
(Pat.load(b:=Pat.var("b"), Pat(), Pat.store(b, Pat(), Pat.var("v"))), lambda ctx,b,v: v),
|
||||
(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda ctx,b,v: v),
|
||||
])
|
||||
|
||||
multioutput = PatternMatcher([(Pat.load(Pat.var("b"), Pat()), lambda ctx,b: ctx.get(b)),])
|
||||
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
|
||||
|
||||
def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -> Tuple[UOp, ScheduleItemContext]:
|
||||
# fuse and fold store -> loads
|
||||
@@ -234,14 +234,14 @@ def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
|
||||
ctx[b] = store
|
||||
return UOp(Ops.LOAD, load.dtype, (b, load.st_arg.to_uop()))
|
||||
|
||||
def PatLoadStore(to_store=Pat()): return Pat.load(b:=Pat.var("b"), Pat(), Pat.store(b, Pat(), to_store, name="store"), name="load")
|
||||
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
|
||||
do_realize = PatternMatcher([
|
||||
# always realize meta ops
|
||||
(PatLoadStore(Pat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize),
|
||||
(Pat((Ops.COPY, Ops.BUFFER_VIEW), src=(Pat.var("u"), Pat.any(PatLoadStore(), PatLoadStore().view(name="v"))), name="root"),
|
||||
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize),
|
||||
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"),
|
||||
lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),)
|
||||
])
|
||||
break_sched = PatternMatcher([(PatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
|
||||
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
|
||||
292
tinygrad/ops.py
292
tinygrad/ops.py
@@ -205,7 +205,7 @@ def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else ls
|
||||
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
||||
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
||||
|
||||
# used for UOp and Pat
|
||||
# used for UOp and UPat
|
||||
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
|
||||
def dfs(x:Any, cache:dict):
|
||||
for s in srcfn(x) or []:
|
||||
@@ -514,7 +514,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
|
||||
def get_location() -> Tuple[str, int]:
|
||||
frm = sys._getframe(1)
|
||||
# find the real frame in the file that has the Pat, TODO: is there a better way to do this?
|
||||
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
||||
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py",
|
||||
"lowerer.py", "cstyle.py"}:
|
||||
frm = frm.f_back
|
||||
@@ -523,75 +523,75 @@ def get_location() -> Tuple[str, int]:
|
||||
def lines(fn) -> List[str]:
|
||||
with open(fn) as f: return f.readlines()
|
||||
|
||||
class Pat(MathTrait):
|
||||
class UPat(MathTrait):
|
||||
__slots__ = ["op", "dtype", "arg", "name", "src"]
|
||||
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
|
||||
src:Optional[Union[Tuple[Pat, ...], List[Pat], Pat]]=None, arg:Any=None,
|
||||
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None,
|
||||
custom_early_reject:Optional[Set[Tuple[Ops, Any]]]=None):
|
||||
self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else op
|
||||
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
|
||||
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
|
||||
self.src: Any = None
|
||||
assert self.name != "ctx", "Pat can't be named ctx"
|
||||
assert self.name != "ctx", "UPat can't be named ctx"
|
||||
|
||||
# try all permutations if it's a list
|
||||
if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
|
||||
# only one if it's a tuple
|
||||
elif isinstance(src, tuple): self.src = [src]
|
||||
# repeat if it's a Pat
|
||||
elif isinstance(src, Pat): self.src = [itertools.repeat(src)]
|
||||
# repeat if it's a UPat
|
||||
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
|
||||
|
||||
self.allowed_len: int = -1 if allow_any_len or isinstance(src, Pat) or src is None else len(src)
|
||||
self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
|
||||
self.location = location or get_location()
|
||||
|
||||
if custom_early_reject is not None: self.early_reject = custom_early_reject
|
||||
else:
|
||||
upat_match = [src] if isinstance(src, Pat) else ([] if src is None else self.src[0])
|
||||
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
|
||||
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
|
||||
|
||||
def named(self, name:str): return Pat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject)
|
||||
def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject)
|
||||
|
||||
@staticmethod
|
||||
def any(*src): return PatAny(src=src)
|
||||
def any(*src): return UPatAny(src=src)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return Pat(dtype=dtype, name=name)
|
||||
def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
|
||||
return Pat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
|
||||
return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return Pat(Ops.CONST, dtype=dtype, arg=b)
|
||||
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
||||
|
||||
# copied from UOp
|
||||
def index(self, idx:Pat, valid:Optional[Pat]=None): return Pat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st=None, **kwargs): return Pat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
|
||||
def cast(self, dtype=None): return Pat(Ops.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return Pat(Ops.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return Pat(Ops.GEP, None, (self,), (i,))
|
||||
def load(self, *src:Pat, **kwargs): return Pat(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:Pat, **kwargs): return Pat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def assign(self, x:Pat): return Pat(Ops.ASSIGN, self.dtype, (self,x))
|
||||
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
|
||||
def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
|
||||
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
def const_like(self, b:ConstLike): return Pat.const(self.dtype, cast(ConstType, b))
|
||||
def alu(self, arg, *src:Pat):
|
||||
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
|
||||
def alu(self, arg, *src:UPat):
|
||||
asrc = (self,)+src
|
||||
return Pat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
|
||||
return UPat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
|
||||
|
||||
def printable(self:Pat) -> str:
|
||||
def printable(self:UPat) -> str:
|
||||
try: return lines(self.location[0])[self.location[1]-1].strip()
|
||||
except FileNotFoundError: return "<missing>"
|
||||
|
||||
def __repr__(self):
|
||||
def rep(x):
|
||||
form = "Pat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
|
||||
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
|
||||
return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
|
||||
set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
|
||||
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
|
||||
|
||||
def match(self:Pat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
if (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
|
||||
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
|
||||
(self.arg is not None and self.arg != uop.arg) or \
|
||||
@@ -607,8 +607,8 @@ class Pat(MathTrait):
|
||||
res.extend(stores)
|
||||
return res
|
||||
|
||||
class PatAny(Pat):
|
||||
def match(self:Pat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
class UPatAny(UPat):
|
||||
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
ret = []
|
||||
for x in self.src[0]:
|
||||
if (match:=x.match(uop, store.copy())): ret.extend(match)
|
||||
@@ -624,10 +624,10 @@ def deconstruct_function(fxn:Callable) -> Tuple:
|
||||
return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[Pat, Callable]]):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
self.patterns = patterns
|
||||
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
|
||||
self.pdict: Dict[Tuple[Ops, Any], List[Tuple[Pat, Callable, Set, bool]]] = {}
|
||||
self.pdict: Dict[Tuple[Ops, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
|
||||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
assert p.op is not None
|
||||
@@ -652,12 +652,12 @@ class PatternMatcher:
|
||||
# *** tracking pattern matcher ***
|
||||
|
||||
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
|
||||
match_stats:Dict[Pat, List[Union[int, float]]] = dict()
|
||||
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
||||
@dataclass(frozen=True)
|
||||
class TrackedRewriteContext:
|
||||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
matches: List[Tuple[UOp, Optional[UOp], Optional[Pat], float]] = field(default_factory=list) # all matches of sparents
|
||||
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
|
||||
|
||||
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
@@ -676,7 +676,7 @@ def track_rewrites(named=False):
|
||||
return _decorator
|
||||
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[Pat, Callable]]):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
super().__init__(patterns)
|
||||
for p,_ in self.patterns:
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
@@ -745,84 +745,84 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
# this is the matcher for the final rendered UOps
|
||||
# matcher functions returns True or False (or None to not match)
|
||||
spec = PatternMatcher([
|
||||
(Pat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
||||
(Pat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
||||
(Pat(Ops.DEFINE_ACC, src=(Pat.var("c"),), name="x", allow_any_len=True),
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
||||
(UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
|
||||
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
|
||||
(Pat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
|
||||
(Pat(Ops.RANGE, src=(Pat(name="x"), Pat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
|
||||
(Pat(Ops.SPECIAL, src=()), lambda: True),
|
||||
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
|
||||
(UPat(Ops.SPECIAL, src=()), lambda: True),
|
||||
|
||||
# TODO: confirm the args of both of these are shapetrackers
|
||||
(Pat(Ops.VIEW, src=()), lambda: True),
|
||||
(Pat(Ops.VIEW, src=(Pat(),)), lambda: True),
|
||||
(UPat(Ops.VIEW, src=()), lambda: True),
|
||||
(UPat(Ops.VIEW, src=(UPat(),)), lambda: True),
|
||||
|
||||
(Pat(Ops.VALID, dtypes.bool, (Pat(Ops.VIEW),)), lambda: True),
|
||||
(Pat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
|
||||
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
|
||||
|
||||
# early LOAD has a <buf, shapetracker, store?>
|
||||
(Pat(Ops.LOAD, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW))), lambda: True),
|
||||
(Pat(Ops.LOAD, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW), Pat(Ops.STORE))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
|
||||
|
||||
# early STORE has a <buf, shapetracker, val>
|
||||
(Pat(Ops.STORE, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat(Ops.VIEW), Pat())), lambda: True),
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
|
||||
|
||||
# **** new style load/store ****
|
||||
|
||||
# INDEX is used in new style load/store
|
||||
(Pat(Ops.INDEX, src=(Pat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), Pat())), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
|
||||
|
||||
# LOAD takes a <bufidx, alt?, gate?, barrier?>
|
||||
(Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)),)), lambda: True),
|
||||
(Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)), Pat((Ops.IF, Ops.BARRIER)))), lambda: True),
|
||||
(Pat(Ops.LOAD, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(name="alt"), Pat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
|
||||
|
||||
# STORE takes a <bufidx, val, gate?>
|
||||
(Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat())), lambda: True),
|
||||
(Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(), Pat(dtype=dtypes.bool))), lambda: True),
|
||||
(Pat(Ops.STORE, dtype=dtypes.void, src=(Pat((Ops.INDEX, Ops.CAST)), Pat(), Pat(Ops.IF))), lambda: True),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
|
||||
|
||||
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
|
||||
(Pat(Ops.ALU, name="w", src=(Pat(dtype=dtypes.bool), Pat(name="x"), Pat(name="y")), arg=TernaryOps.WHERE),
|
||||
(UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
|
||||
lambda w,x,y: w.dtype == x.dtype == y.dtype),
|
||||
(Pat(Ops.ALU, dtype=dtypes.bool, src=(Pat(name="x"), Pat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
|
||||
(Pat(Ops.ALU, dtype=dtypes.bool, src=(Pat(name="x"), Pat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
|
||||
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
|
||||
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
|
||||
# and SHL/SHR, the shift distance is an int
|
||||
(Pat(Ops.ALU, src=(Pat(name="x"), Pat(name="y")), name="alu", arg=BinaryOps.SHL),
|
||||
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
|
||||
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
|
||||
(Pat(Ops.ALU, src=(Pat(name="x"), Pat(name="y")), name="alu", arg=BinaryOps.SHR),
|
||||
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
|
||||
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
|
||||
(Pat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
||||
(Pat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
||||
(UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
|
||||
|
||||
(Pat(Ops.ASSIGN, src=(Pat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), Pat())), lambda: True),
|
||||
(Pat(Ops.ENDRANGE, dtype=dtypes.void, src=(Pat(Ops.RANGE),)), lambda: True),
|
||||
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
|
||||
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
|
||||
|
||||
# all WMMA has 3 args, <x, w, acc>
|
||||
(Pat(Ops.WMMA, src=(Pat(), Pat(), Pat())), lambda: True),
|
||||
(Pat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(Pat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# if has a <gate, barrier?>
|
||||
(Pat(Ops.IF, dtype=dtypes.void, src=(Pat(),)), lambda: True),
|
||||
(Pat(Ops.IF, dtype=dtypes.void, src=(Pat(), Pat(Ops.BARRIER))), lambda: True),
|
||||
(Pat(Ops.ENDIF, dtype=dtypes.void, src=(Pat(Ops.IF),)), lambda: True),
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
|
||||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||
|
||||
(Pat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
|
||||
(Pat(Ops.GEP, src=(Pat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
(Pat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(Pat((Ops.BITCAST, Ops.CAST), src=(Pat(),), name="x"), lambda x: x.arg is None),
|
||||
(Pat(Ops.BARRIER, dtypes.void, src=Pat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
|
||||
(UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
|
||||
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
|
||||
|
||||
# NOTE: for testing, we let sinks be anything
|
||||
#(Pat(UOps.SINK, src=Pat(UOps.STORE)), lambda: True),
|
||||
(Pat(Ops.SINK, dtypes.void), lambda: True),
|
||||
(Pat(Ops.NOOP), lambda: True),
|
||||
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
|
||||
(UPat(Ops.SINK, dtypes.void), lambda: True),
|
||||
(UPat(Ops.NOOP), lambda: True),
|
||||
|
||||
# PTX LOAD/STORE
|
||||
(Pat((Ops.LOAD, Ops.STORE), src=(Pat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
(Pat(Ops.BARRIER, dtypes.void, src=Pat(Ops.STORE, src=(Pat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
|
||||
])
|
||||
|
||||
def type_verify(uops:List[UOp]):
|
||||
@@ -1013,125 +1013,125 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
|
||||
|
||||
symbolic_simple = PatternMatcher([
|
||||
# ** self folding **
|
||||
(Pat.var("x") + 0, lambda x: x), # x+0 -> x
|
||||
(Pat.var("x") * 1, lambda x: x), # x*1 -> x
|
||||
(Pat.var("x") // Pat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
|
||||
(Pat.var("x") // 1, lambda x: x), # x//1 -> x
|
||||
(Pat.var("x") // -1, lambda x: -x), # x//-1 -> -x
|
||||
(Pat.var("x") / Pat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
|
||||
((Pat.var("x") * Pat.var("x2")) / Pat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
((Pat.var() % Pat.var("y")).named("base") % Pat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
|
||||
(Pat.var("x")%Pat.cvar("c")+(Pat.var("x")//Pat.cvar("c"))*Pat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
(Pat.var("x", dtype=dtypes.bool) & Pat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
||||
(Pat.var("x", dtype=dtypes.bool) | Pat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
||||
(Pat.var("x").maximum(Pat.var("x")), lambda x: x),
|
||||
((Pat.var("x") & Pat.var("x")), lambda x: x),
|
||||
((Pat.var("x") | Pat.var("x")), lambda x: x),
|
||||
(Pat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
|
||||
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
|
||||
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
|
||||
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
|
||||
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
|
||||
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
|
||||
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
|
||||
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
|
||||
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
||||
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
||||
(UPat.var("x").maximum(UPat.var("x")), lambda x: x),
|
||||
((UPat.var("x") & UPat.var("x")), lambda x: x),
|
||||
((UPat.var("x") | UPat.var("x")), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
|
||||
# ** zero folding **
|
||||
(Pat.var("x") < Pat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
|
||||
(Pat.var("x", dtype=dtypes.ints) != Pat.var("x", dtype=dtypes.ints),
|
||||
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
|
||||
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
|
||||
lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(Pat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# ** constant folding **
|
||||
(Pat(Ops.ALU, name="root", src=Pat((Ops.VCONST, Ops.CONST))),
|
||||
(UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
|
||||
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(Pat.var('x', dtype=dtypes.bool) * Pat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
||||
(Pat.var('x', dtype=dtypes.bool) + Pat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
||||
(Pat.var('x', dtype=dtypes.bool).maximum(Pat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
||||
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
|
||||
# *** cast ***
|
||||
(Pat(Ops.CAST, name="root", src=Pat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
(Pat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
(UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** COMMUTATIVE flipping **
|
||||
*[(Pat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
|
||||
*[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
|
||||
# group like
|
||||
((Pat.var("x") + Pat.var("y")) + Pat.var("x") * Pat.cvar("c"), lambda x,y,c: (x+x*c)+y),
|
||||
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
|
||||
# ** combine terms **
|
||||
(Pat.var("x") * Pat.cvar("c0") + Pat.var("x") * Pat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(Pat.var("x") + Pat.var("x") * Pat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
(Pat.var("x") + Pat.var("x"), lambda x: x*2), # (x+x)-> x*2
|
||||
((Pat.var("x") / Pat.var("x2")) / Pat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
||||
(-1 * (Pat.var("x") + Pat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
|
||||
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
|
||||
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
||||
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(Pat.var().where(Pat.var("val"), Pat.var("val")), lambda val: val),
|
||||
(Pat.cvar("gate", vec=False).where(Pat.var("c0"), Pat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ALU min==max -> CONST (slow!)
|
||||
(Pat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
(UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
(Pat.maximum(Pat.var("x"), Pat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
# TODO: why does this rule break beautiful_mnist?
|
||||
#((Pat.var("x")+Pat.var("z")).maximum(Pat.var("y")+Pat.var("z")), lambda x,y,z: x.maximum(y) + z),
|
||||
((Pat.var("x")*Pat.cvar("c1")).maximum(Pat.var("x")*Pat.cvar("c2")), max_var_const),
|
||||
#((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
|
||||
((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
|
||||
# ** two stage ALU folding **
|
||||
((Pat.var("x") + Pat.cvar("c1")) + Pat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
|
||||
((Pat.var("x") * Pat.cvar("c1")) * Pat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
|
||||
((Pat.var("x") & Pat.cvar("c1")) & Pat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
|
||||
((Pat.var("x") | Pat.cvar("c1")) | Pat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
|
||||
((Pat.cvar("c0") + Pat.var("x")) < Pat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
|
||||
((Pat.var("x") // Pat.cvar("c1")) // Pat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
|
||||
((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
|
||||
((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
|
||||
((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
|
||||
((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
|
||||
((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
|
||||
((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
|
||||
# ** lt **
|
||||
# c0*x<c1 for positive int c0,c1
|
||||
((Pat.cvar("c0", vec=False)*Pat.var("x", dtype=dtypes.ints)).lt(Pat.cvar("c1", vec=False)),
|
||||
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
|
||||
# c0*x<c1 for negative int c0 and non-positive c1
|
||||
((Pat.cvar("c0", vec=False)*Pat.var("x", dtype=dtypes.ints)).lt(Pat.cvar("c1", vec=False)),
|
||||
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
|
||||
# x//c0<c1 for positive int c0
|
||||
((Pat.var("x", dtype=dtypes.ints)//Pat.cvar("c0", vec=False)).lt(Pat.cvar("c1", vec=False)),
|
||||
((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
|
||||
# mul add lt
|
||||
(((Pat.cvar("c0", vec=False)*Pat.var("x"))+Pat.var("x2")).lt(Pat.cvar("c1", vec=False)),
|
||||
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
|
||||
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
|
||||
(Pat(Ops.ALU, arg=BinaryOps.ADD, src=(Pat.var("x"), Pat.cvar("c1"))) + Pat.var("y"), lambda x,c1,y: (x+y)+c1),
|
||||
(Pat(Ops.ALU, arg=BinaryOps.MUL, src=(Pat.var("x"), Pat.cvar("c1"))) * Pat.var("y"), lambda x,c1,y: (x*y)*c1),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
|
||||
# *** rules from symbolic ***
|
||||
# unrolled arange div folding
|
||||
(Pat(Ops.ALU, name="divs", src=[Pat(), Pat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
|
||||
(UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
|
||||
# generic lt folding
|
||||
(Pat.var("x", dtypes.sints).lt(Pat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
|
||||
(UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
|
||||
# canonicalize a simplex with positive coefficients > 0
|
||||
# not x < 1 -> X > 0
|
||||
(Pat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
|
||||
(UPat.var("x", dtypes.ints).lt(1).ne(True), lambda x: newx.lt(1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
|
||||
# ** div **
|
||||
# # div folding
|
||||
(Pat.var("x", dtypes.sints) // Pat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
|
||||
(UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=div_folding(x,c.arg)) is not None else None),
|
||||
# ** mod **
|
||||
# mod folding
|
||||
(Pat.var("x") % Pat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
])
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
# ** combine terms (opinionated) **
|
||||
(-1 * (Pat.var("x") + Pat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
||||
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
||||
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
|
||||
((Pat.var("x", dtypes.ints) + Pat.var("y")) * Pat.cvar("c"), lambda x,y,c: x*c+y*c),
|
||||
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
|
||||
])
|
||||
|
||||
_substitute = PatternMatcher([(Pat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
|
||||
# for debug
|
||||
syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
|
||||
BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
|
||||
renderer = PatternMatcher([
|
||||
(Pat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
|
||||
(Pat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
|
||||
(Pat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
|
||||
(Pat(Ops.BIND, src=Pat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
|
||||
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
|
||||
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=TernaryOps.MULACC),
|
||||
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
|
||||
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
|
||||
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
|
||||
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
|
||||
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x", arg=TernaryOps.WHERE),
|
||||
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
|
||||
(Pat(Ops.ALU, src=Pat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
|
||||
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
|
||||
])
|
||||
|
||||
# *** what was symbolic.py ***
|
||||
|
||||
@@ -2,66 +2,66 @@ from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
|
||||
import os, math
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, Pat, cast_float_to_bf16
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
|
||||
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.renderer import Renderer, TensorCore
|
||||
|
||||
base_rewrite = PatternMatcher([
|
||||
(Pat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
|
||||
(Pat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
|
||||
(Pat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
|
||||
(Pat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
|
||||
(Pat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
|
||||
(UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
|
||||
(UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
|
||||
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
|
||||
# r method accesses
|
||||
(Pat(Ops.RANGE, name="x"),
|
||||
(UPat(Ops.RANGE, name="x"),
|
||||
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
|
||||
(Pat(Ops.VECTORIZE, name="x"),
|
||||
(UPat(Ops.VECTORIZE, name="x"),
|
||||
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
|
||||
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
|
||||
(Pat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"),
|
||||
(Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
|
||||
(Pat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
|
||||
(Pat(Ops.BARRIER), lambda ctx: ctx.barrier),
|
||||
(Pat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
|
||||
(Pat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
|
||||
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
|
||||
(UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
|
||||
# const
|
||||
(Pat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
|
||||
(Pat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
|
||||
(Pat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
|
||||
(Pat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
|
||||
(Pat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
|
||||
(Pat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
|
||||
(Pat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
|
||||
(Pat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
|
||||
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
|
||||
(UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
|
||||
(UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
|
||||
# consts are rendered to larger type and casted
|
||||
(Pat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
|
||||
(Pat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
|
||||
(Pat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
|
||||
(UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
|
||||
(UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
|
||||
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
|
||||
# default const render
|
||||
(Pat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# new load/store
|
||||
(Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var('idx'))),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
|
||||
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"),
|
||||
(Pat(Ops.LOAD, src=(Pat.var('bidx'), Pat.var("var"), Pat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(Pat(Ops.LOAD, src=(Pat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
|
||||
(Pat(Ops.STORE, src=(Pat.var('bidx'), Pat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
# alu/gep
|
||||
(Pat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
|
||||
(UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
|
||||
*([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
|
||||
(Pat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
])
|
||||
|
||||
extra_pm = PatternMatcher([
|
||||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(Pat(Ops.BITCAST, name="x"),
|
||||
(UPat(Ops.BITCAST, name="x"),
|
||||
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
|
||||
# gate any stores that aren't gated with ifs
|
||||
(Pat(Ops.STORE, dtype=dtypes.void, src=(Pat(), Pat(), Pat(dtype=dtypes.bool)), name="store"),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
|
||||
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
|
||||
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||
(Pat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
(UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
])
|
||||
|
||||
def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
|
||||
@@ -214,13 +214,13 @@ class OpenCLRenderer(CStyleLanguage):
|
||||
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
|
||||
# load/store image (OpenCL)
|
||||
(Pat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))), Pat.var("var"), Pat.var("gate"))),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
|
||||
(Pat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))),)),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
|
||||
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
|
||||
(Pat(Ops.STORE, src=(Pat.var('buf').index(Pat.var('idx', dtypes.int.vec(2))), Pat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
|
||||
]) + base_rewrite
|
||||
|
||||
@@ -234,8 +234,8 @@ class IntelRenderer(OpenCLRenderer):
|
||||
st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(Pat(Ops.CAST, dtype=dtypes.bfloat16, src=(Pat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
|
||||
(Pat(Ops.CAST, dtype=dtypes.float, src=(Pat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
|
||||
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
|
||||
(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
|
||||
]) + OpenCLRenderer.string_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
@@ -272,13 +272,13 @@ class MetalRenderer(CStyleLanguage):
|
||||
# upcast to float32 all the ops that don't support bfloat16
|
||||
extra_matcher = PatternMatcher([
|
||||
# NOTE: this is copied from PTX
|
||||
*[(Pat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
|
||||
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
|
||||
for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
|
||||
]) + extra_pm
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(Pat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
|
||||
]) + base_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
|
||||
@@ -387,20 +387,20 @@ class AMDRenderer(CStyleLanguage):
|
||||
type_map = {dtypes.bfloat16: "hip_bfloat16"}
|
||||
extra_matcher = PatternMatcher([
|
||||
# cast bfloat16 alus to float
|
||||
(Pat(Ops.ALU, arg=TernaryOps.WHERE, src=(Pat.var("b"), Pat.var("x", dtype=dtypes.bfloat16), Pat.var("y", dtype=dtypes.bfloat16))),
|
||||
(UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
|
||||
lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
|
||||
(Pat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
|
||||
(UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
|
||||
lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
|
||||
(Pat(Ops.ALU, dtypes.bool, name="alu", src=(Pat.var("x", dtype=dtypes.bfloat16), Pat.var("y", dtype=dtypes.bfloat16))),
|
||||
(UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
|
||||
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
|
||||
# add float intermediate casting for bfloat16
|
||||
(Pat(Ops.CAST, name="x", src=Pat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
|
||||
(Pat(Ops.CAST, dtypes.bfloat16, Pat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
|
||||
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
|
||||
(UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
|
||||
# bfloat16 casting
|
||||
(Pat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
|
||||
(Pat(Ops.CAST, dtype=dtypes.float, src=Pat.var("x", dtype=dtypes.bfloat16)),
|
||||
(UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
|
||||
(UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
|
||||
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
|
||||
(Pat(Ops.CAST, dtype=dtypes.bfloat16, src=Pat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
|
||||
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
|
||||
|
||||
def render_vector_prefix(self, dtype:DType) -> str:
|
||||
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, Pat
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, UPat
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
@@ -35,23 +35,23 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
ptx_matcher = PatternMatcher([
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
(Pat.var('x', dtype=dtypes.bool).ne(Pat.var('y')), lambda x,y: x^y),
|
||||
(Pat.var('x', dtype=dtypes.bool).lt(Pat.var('y')), lambda x,y: (x^True)&y),
|
||||
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
||||
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
|
||||
# upcast to float32 all the ops that don't support half
|
||||
*[(Pat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
|
||||
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
|
||||
for op in asm_for_op.keys() if op not in supports_half],
|
||||
# load/store bool -> uint8
|
||||
(Pat(Ops.LOAD, dtypes.bool, src=(Pat(dtype=dtypes.int64),), name="x", allow_any_len=True),
|
||||
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
|
||||
(Pat(Ops.STORE, src=(Pat(dtype=dtypes.int64), Pat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
|
||||
# load/store use pointer arithmetic, and the cast does nothing
|
||||
(Pat(Ops.INDEX, src=(Pat.var("buf"), Pat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
|
||||
(Pat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
|
||||
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
|
||||
# ptx shr and shl instructions require y to be uint
|
||||
(Pat.var("x") << Pat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
|
||||
(Pat.var("x") >> Pat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
|
||||
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
|
||||
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
|
||||
])
|
||||
|
||||
class PTXRenderer(Renderer):
|
||||
|
||||
@@ -25,7 +25,7 @@ class GraphRewriteMetadata:
|
||||
kernel_name: Optional[str]
|
||||
"""The kernel calling graph_rewrite"""
|
||||
upats: List[Tuple[Tuple[str, int], str, float]]
|
||||
"""List of all the applied Pats"""
|
||||
"""List of all the applied UPats"""
|
||||
|
||||
@dataclass
|
||||
class GraphRewriteDetails(GraphRewriteMetadata):
|
||||
|
||||
Reference in New Issue
Block a user