mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
remove nop, use upat [run_process_replay] (#6489)
* remove nop, use upat [run_process_replay] * mypy passes * no wonder nothing worked * fixes
This commit is contained in:
@@ -5,16 +5,16 @@ from tinygrad import dtypes, Device
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||
from tinygrad.ops import NOp, PatternMatcher
|
||||
from tinygrad.ops import UPat, PatternMatcher
|
||||
from tinygrad.codegen.lowerer import ast_to_uop
|
||||
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
(NOp.cvar('x') + NOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
|
||||
(NOp.cvar('x') * NOp.cvar('y') * NOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
|
||||
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
|
||||
(UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||
(UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
|
||||
(UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
|
||||
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
|
||||
])
|
||||
|
||||
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))
|
||||
|
||||
@@ -7,7 +7,6 @@ from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UOps, UOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
|
||||
from tinygrad.ops import NOp
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
@@ -389,10 +388,6 @@ class TestUOpStr(unittest.TestCase):
|
||||
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
|
||||
assert_equiv_uops(sink, eval(str(sink)))
|
||||
|
||||
def test_nop_str(self):
|
||||
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")
|
||||
assert str(eval(str(a))) == str(a)
|
||||
|
||||
def test_variable_const(self):
|
||||
# TODO: this is not possible after VALID.
|
||||
uop = UOp(UOps.CONST, dtypes.int, (), arg=Variable("a",1,10))
|
||||
|
||||
@@ -30,9 +30,9 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
def test_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(UPat(UOps.CONST, 0, name="x"), lambda x: x),
|
||||
(UPat(UOps.CONST, False, name="x"), lambda x: x),
|
||||
(UPat(UOps.ALU, BinaryOps.MAX, name="x"), lambda x: x),
|
||||
(UPat(UOps.CONST, arg=0, name="x"), lambda x: x),
|
||||
(UPat(UOps.CONST, arg=False, name="x"), lambda x: x),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
|
||||
])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
@@ -47,7 +47,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
def test_filter_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)], name="x"),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(UOps.CONST, name="c"), UPat(UOps.CONST, arg=2)], name="x"),
|
||||
lambda x,c: x if c.arg in {1, -1} else None)
|
||||
])
|
||||
y1 = UOp(UOps.CONST, dtypes.int, arg=1)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Tuple, List
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.ops import UOp
|
||||
|
||||
TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64}
|
||||
TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
|
||||
def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
|
||||
"""replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
|
||||
|
||||
@@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
|
||||
from tinygrad.ops import NOp, UPat, PatternMatcher, graph_rewrite
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, all_same, partition
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||
@@ -83,9 +83,9 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
|
||||
|
||||
float4_folding = PatternMatcher([
|
||||
(UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
|
||||
(UPat({UOps.BARRIER, UOps.SINK}, src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
|
||||
(UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.REDUCE), name="vec"), vectorize_reduce),
|
||||
(UPat(UOps.VECTORIZE, src=UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}), name="vec"), vectorize_alu),
|
||||
(UPat(UOps.VECTORIZE, src=UPat((UOps.ALU, UOps.CAST, UOps.BITCAST)), name="vec"), vectorize_alu),
|
||||
])
|
||||
|
||||
# ***** mod *****
|
||||
@@ -222,145 +222,147 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
|
||||
# this is symbolic 2.0
|
||||
constant_folder = PatternMatcher([
|
||||
# bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
|
||||
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
|
||||
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
|
||||
# VECTORIZE/GEP
|
||||
(NOp(UOps.GEP, src=(NOp(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]),
|
||||
*[(NOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(NOp(UOps.GEP, dtypes.float,
|
||||
src=(NOp.var('x', dtype=dtypes.float.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in ([2, 4, 8, 16] + ([256] if AMX else []))],
|
||||
*[(NOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(NOp(UOps.GEP, dtypes.half,
|
||||
src=(NOp.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
|
||||
(UPat(UOps.GEP, src=(UPat(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]),
|
||||
*[(UPat(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UPat(UOps.GEP, dtypes.float,
|
||||
src=(UPat.var('x', dtype=dtypes.float.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in ([2, 4, 8, 16] + ([256] if AMX else []))],
|
||||
*[(UPat(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UPat(UOps.GEP, dtypes.half,
|
||||
src=(UPat.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
|
||||
# tensor core with a 0 input is acc
|
||||
*[(NOp(UOps.WMMA, src=(NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var(), NOp.var('acc'))),
|
||||
*[(UPat(UOps.WMMA, src=(UPat(UOps.VECTORIZE, src=tuple(UPat.const(None, 0.0) for _ in range(i))), UPat.var(), UPat.var('acc'))),
|
||||
lambda acc: acc) for i in [2, 4, 8]],
|
||||
*[(NOp(UOps.WMMA, src=(NOp.var(), NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var('acc'))),
|
||||
*[(UPat(UOps.WMMA, src=(UPat.var(), UPat(UOps.VECTORIZE, src=tuple(UPat.const(None, 0.0) for _ in range(i))), UPat.var('acc'))),
|
||||
lambda acc: acc) for i in [2, 4, 8]],
|
||||
# tensor core cleanups
|
||||
*[(NOp(UOps.REDUCE, src=(NOp(UOps.EXPAND, src=tuple(NOp(UOps.GEP, dtypes.float, src=(NOp.var('x'),), arg=i) for i in range(j)), name="expand"),)
|
||||
*[(UPat(UOps.REDUCE, src=(UPat(UOps.EXPAND, src=tuple(UPat(UOps.GEP, dtypes.float, src=(UPat.var('x'),), arg=i) for i in range(j)), name="expand"),)
|
||||
,name="reduce", allow_any_len=True), reduce_before_expand) for j in ([2,4,8] + ([16,256] if AMX else []))],
|
||||
(NOp.var("add") + NOp(UOps.WMMA, name="wmma"),
|
||||
(UPat.var("add") + UPat(UOps.WMMA, name="wmma"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
# threefry
|
||||
(NOp(UOps.ALU, dtype=dtypes.uint64, src=(NOp.var("x"), NOp.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
(UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("seed")), arg=BinaryOps.THREEFRY), threefry2x32),
|
||||
# extra arange loop folding because we don't fold adds. TODO: fold adds
|
||||
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
|
||||
NOp.var("idx2") + NOp.var("idx3"))
|
||||
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng") +
|
||||
NOp.var("idx2"))
|
||||
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") +
|
||||
UPat.var("idx2") + UPat.var("idx3")).lt(UPat.cvar("compval"))
|
||||
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng") +
|
||||
UPat.var("idx2")).lt(UPat.cvar("compval"))
|
||||
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding (reduce)
|
||||
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
|
||||
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
|
||||
.lt(UPat.cvar("compval"))
|
||||
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding (unrolled)
|
||||
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
|
||||
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)) + NOp.var("extra"),),
|
||||
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
|
||||
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# unrolled arange div folding
|
||||
(NOp.var('divs') + NOp.cvar('c'), fold_unrolled_divs),
|
||||
(UPat.var('divs') + UPat.cvar('c'), fold_unrolled_divs),
|
||||
# indexing (with a multiply offset)!
|
||||
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
|
||||
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).cast()*
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('add')+UPat.var('mul')*UPat(UOps.RANGE, name="rng")), name="ld"),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
|
||||
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.RANGE, name="rng")), name="ld"),),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).cast()*
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat(UOps.RANGE, name="rng")), name="ld"),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
|
||||
lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
|
||||
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).where(
|
||||
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"), NOp.const(None, 0.0)),),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var('idx').eq(UPat(UOps.RANGE, name="rng")).where(
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('add')+UPat.var('mul')*UPat(UOps.RANGE, name="rng")), name="ld"), UPat.const(None, 0.0)),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
# max folding
|
||||
(NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
(UPat.max(UPat.var('x'), UPat.var('y')), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
# GEP/CAST const rules
|
||||
(NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(UOps.GEP, src=(UPat.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const_like(c.arg)),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(NOp.var().where(NOp.var("val"), NOp.var("val")), lambda val: val),
|
||||
(NOp.cvar('gate').where(NOp.var('c0'), NOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar('gate').where(UPat.var('c0'), UPat.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ** constant folding **
|
||||
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
|
||||
# ** self folding **
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
(NOp(UOps.REDUCE, src=(NOp.var('x'),)), lambda x: x), # a REDUCE without ranges is a NOOP
|
||||
(NOp.var('x') + 0, lambda x: x), # x+0 -> x
|
||||
(NOp.var('x') * 1, lambda x: x), # x*1 -> x
|
||||
(NOp.var('x') // NOp.var('x'), lambda x: x.const_like(1)), # x//x -> 1
|
||||
(NOp.var('x') // 1, lambda x: x), # x//1 -> x
|
||||
(NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
|
||||
(NOp.var('x') / NOp.var('x'), lambda x: x.const_like(1)), # x/x -> 1
|
||||
((NOp.var("x") * NOp.var("x2")) / NOp.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
(NOp.var('x', dtype=dtypes.bool) & NOp.cvar('c'), lambda x,c: x if c.arg else c),
|
||||
(NOp.var('x', dtype=dtypes.bool) | NOp.cvar('c'), lambda x,c: c if c.arg else x),
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var('x'),)), lambda x: x), # a REDUCE without ranges is a NOOP
|
||||
(UPat.var('x') + 0, lambda x: x), # x+0 -> x
|
||||
(UPat.var('x') * 1, lambda x: x), # x*1 -> x
|
||||
(UPat.var('x') // UPat.var('x'), lambda x: x.const_like(1)), # x//x -> 1
|
||||
(UPat.var('x') // 1, lambda x: x), # x//1 -> x
|
||||
(UPat.var('x') // -1, lambda x: -x), # x//-1 -> -x
|
||||
(UPat.var('x') / UPat.var('x'), lambda x: x.const_like(1)), # x/x -> 1
|
||||
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
(UPat.var('x', dtype=dtypes.bool) & UPat.cvar('c'), lambda x,c: x if c.arg else c),
|
||||
(UPat.var('x', dtype=dtypes.bool) | UPat.cvar('c'), lambda x,c: c if c.arg else x),
|
||||
# ** zero folding **
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(NOp.var('x') * 0, lambda x: x.const_like(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
(UPat.var('x') * 0, lambda x: x.const_like(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# min==max -> CONST (slow!)
|
||||
(UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
(UPat((UOps.ALU, UOps.DEFINE_VAR), name='x'), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# ** load/store folding **
|
||||
(NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.load(NOp.var("buf"), NOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
|
||||
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.load(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
|
||||
# ** two stage add/mul folding **
|
||||
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x+(c1+c2)),
|
||||
((NOp.var("x") * NOp.cvar("c1")) * NOp.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
|
||||
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x+(c1+c2)),
|
||||
((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
|
||||
# *** rules from symbolic ***
|
||||
# ** lt **
|
||||
# c0*x<c1 for positive int c0,c1
|
||||
((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
|
||||
((UPat.cvar('c0')*UPat.var('x')).lt(UPat.cvar('c1')),
|
||||
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
|
||||
# c0*x<c1 for negative int c0 and non-positive c1
|
||||
((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
|
||||
((UPat.cvar('c0')*UPat.var('x')).lt(UPat.cvar('c1')),
|
||||
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if dtypes.is_int(x.dtype) and c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
|
||||
# mul add lt
|
||||
(((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
|
||||
(((UPat.cvar('c0')*UPat.var('x'))+UPat.var('x2')).lt(UPat.cvar('c1')),
|
||||
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
|
||||
# generic lt folding
|
||||
(NOp.var('x').lt(NOp.cvar('c')),
|
||||
(UPat.var('x').lt(UPat.cvar('c')),
|
||||
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
|
||||
# ** div **
|
||||
# # div folding
|
||||
(NOp.var('x') // NOp.cvar('c'), lambda x,c:
|
||||
(UPat.var('x') // UPat.cvar('c'), lambda x,c:
|
||||
newx if 0 < c.arg and not dtypes.is_unsigned(x.dtype) and (newx:=div_folding(x,c.arg)) is not None else None),
|
||||
# ** mod **
|
||||
# mod folding
|
||||
(NOp.var('x') % NOp.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
(UPat.var('x') % UPat.cvar('c'), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
# mul mod
|
||||
((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None),
|
||||
((UPat.cvar('c0')*UPat.var('x')) % UPat.cvar('c1'), lambda x,c0,c1: (x%(c1//c0))*c0 if c1.arg%c0.arg == 0 else None),
|
||||
# ** combine terms **
|
||||
(NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
(NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(NOp.var("x") + NOp.var("x") * NOp.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
(NOp.var("x") + NOp.var("x"), lambda x: x*2), # (x+x)-> x*2
|
||||
((NOp.var("x") // NOp.cvar("c0")) // NOp.cvar("c1"), lambda x,c0,c1: x//(c0*c1)), # (x//c0)//c1 -> x//(c0*c1)
|
||||
((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
||||
(-1 * (NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
||||
((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0
|
||||
(UPat.var('x')%UPat.cvar('c')+(UPat.var('x')//UPat.cvar('c'))*UPat.cvar('c'), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
|
||||
((UPat.var("x") // UPat.cvar("c0")) // UPat.cvar("c1"), lambda x,c0,c1: x//(c0*c1)), # (x//c0)//c1 -> x//(c0*c1)
|
||||
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
||||
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
||||
((UPat.cvar("c0") + UPat.var("x")).lt(UPat.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0
|
||||
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
|
||||
((NOp.var("x") + NOp.var("y")) * NOp.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None),
|
||||
((UPat.var("x") + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None),
|
||||
# x!=0 -> (bool)x
|
||||
(NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
|
||||
(UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
|
||||
# bitwise noops
|
||||
((NOp.var("x") & NOp.var("x")), lambda x: x),
|
||||
((NOp.var("x") | NOp.var("x")), lambda x: x),
|
||||
((UPat.var("x") & UPat.var("x")), lambda x: x),
|
||||
((UPat.var("x") | UPat.var("x")), lambda x: x),
|
||||
# TODO: can do the invert of this (flip alt/load) when we fix double ops
|
||||
(NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("gate").where(NOp.var("alt"), NOp.load(NOp.var("buf"), NOp.var("idx")))),
|
||||
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat.var("buf"), UPat.var("idx")))),
|
||||
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
|
||||
# fold gated LOAD/STORE
|
||||
(NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
|
||||
(NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
|
||||
(UPat.load(UPat.var("buf"), UPat.var("idx"), UPat.var("var"), UPat.const(dtypes.bool, True)),
|
||||
lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
|
||||
(UPat.load(UPat.var("buf"), UPat.var("idx"), UPat.var("var"), UPat.const(dtypes.bool, True), UPat.var("barrier")),
|
||||
lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
|
||||
(NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False)), lambda var: var),
|
||||
(NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
|
||||
(NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("val"), NOp.const(dtypes.bool, True)),
|
||||
(UPat.load(UPat.var(), UPat.var(), UPat.var("var"), UPat.const(dtypes.bool, False)), lambda var: var),
|
||||
(UPat.load(UPat.var(), UPat.var(), UPat.var("var"), UPat.const(dtypes.bool, False), UPat.var()), lambda var: var),
|
||||
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("val"), UPat.const(dtypes.bool, True)),
|
||||
lambda buf,idx,val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda
|
||||
(NOp.store(NOp.var(), NOp.var(), NOp.var(), NOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
|
||||
(UPat.store(UPat.var(), UPat.var(), UPat.var(), UPat.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
|
||||
# remove NOOPs from SINK
|
||||
(NOp(UOps.SINK, name="root"),
|
||||
(UPat(UOps.SINK, name="root"),
|
||||
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
|
||||
# ** move add consts to end (NOTE: this is still happening before constant folding) **
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]),
|
||||
lambda x,c1,y: (x+y)+c1),
|
||||
])
|
||||
|
||||
@@ -446,22 +448,22 @@ def create_gate(root:UOp) -> Optional[UOp]:
|
||||
|
||||
expander = PatternMatcher([
|
||||
# create gate MUST BE BEFORE expander
|
||||
(NOp(UOps.STORE, name="root"), create_gate),
|
||||
(UPat(UOps.STORE, name="root"), create_gate),
|
||||
# do expansion
|
||||
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
|
||||
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
|
||||
(NOp(UOps.CONTRACT, name="con"), do_contract),
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
|
||||
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
|
||||
(UPat(UOps.CONTRACT, name="con"), do_contract),
|
||||
# remove EXPANDs from SINK
|
||||
(NOp(UOps.SINK, name="root"),
|
||||
(UPat(UOps.SINK, name="root"),
|
||||
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
|
||||
if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
|
||||
# BARRIERs aren't actually expanded
|
||||
(NOp(UOps.BARRIER, src=(NOp(UOps.EXPAND, name="ex"),)),
|
||||
(UPat(UOps.BARRIER, src=(UPat(UOps.EXPAND, name="ex"),)),
|
||||
lambda ex: UOp(UOps.EXPAND, dtypes.void, (UOp(UOps.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
|
||||
# empty EXPAND is NOOP
|
||||
(NOp(UOps.EXPAND, src=(NOp.var('x'),), arg=()), lambda x: x),
|
||||
(UPat(UOps.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
|
||||
# EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU
|
||||
(NOp(UOps.EXPAND, name="ex", src=tuple(NOp.var('x').gep(i)+NOp.var('y').gep(i) for i in range(256 if AMX else 8))),
|
||||
(UPat(UOps.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
|
||||
lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
|
||||
])
|
||||
|
||||
@@ -474,16 +476,16 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
|
||||
return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
|
||||
|
||||
reducer = PatternMatcher([
|
||||
(NOp(UOps.REDUCE, name="root"), do_reduce),
|
||||
(UPat(UOps.REDUCE, name="root"), do_reduce),
|
||||
# no ALU on vectorized dtypes
|
||||
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name="alu"), no_vectorized_alu),
|
||||
# delete_redundant_gates (after expand, is this still needed?)
|
||||
(NOp(UOps.STORE, name="root"), delete_redundant_gates),
|
||||
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
|
||||
# late fixup of unfoldable image loads
|
||||
(UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
|
||||
])
|
||||
|
||||
no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE}, name="x"),
|
||||
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
|
||||
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int32, x.src, x.arg) \
|
||||
if x.dtype.scalar() == dtypes.pyint else None)])
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ reduceop_fusor = PatternMatcher([
|
||||
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.SWIZZLE, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_swizzle_down_through_elementwise),
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
|
||||
100
tinygrad/ops.py
100
tinygrad/ops.py
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, Sequence, DefaultDict
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
|
||||
import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
|
||||
from enum import auto
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
|
||||
from tinygrad.helpers import pretty_print, prod, getenv, all_same, HashEnum
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
@@ -335,7 +335,7 @@ class UOps(HashEnum):
|
||||
ENDIF = auto()
|
||||
|
||||
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
|
||||
|
||||
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
|
||||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
@@ -367,9 +367,6 @@ class UOp(MathTrait):
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self):
|
||||
return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg
|
||||
def commutative(self) -> bool:
|
||||
return (self.op is UOps.ALU and \
|
||||
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
|
||||
# *** uop syntactic sugar
|
||||
@property
|
||||
def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
|
||||
@@ -381,10 +378,19 @@ class UOp(MathTrait):
|
||||
return ret.arg
|
||||
def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
|
||||
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
|
||||
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
|
||||
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
|
||||
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
|
||||
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar(), (self,), i)
|
||||
@classmethod
|
||||
def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore
|
||||
@classmethod
|
||||
def store(cls, *src:UOp): return cls(UOps.STORE, dtypes.void, src)
|
||||
def alu(self, arg, *src:UOp):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
|
||||
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
return type(self)(UOps.ALU, out_dtype, (self,)+src, arg)
|
||||
@classmethod
|
||||
@functools.lru_cache(None)
|
||||
def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b)
|
||||
@@ -395,15 +401,6 @@ class UOp(MathTrait):
|
||||
if dtype is not None and dtype != (sdtype := dtype.scalar()):
|
||||
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
|
||||
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
def alu(self, arg, *src:UOp):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
|
||||
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
return type(self)(UOps.ALU, out_dtype, (self,)+src, arg)
|
||||
@classmethod
|
||||
def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore
|
||||
@classmethod
|
||||
def store(cls, *src:UOp): return cls(UOps.STORE, dtypes.void, src)
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
|
||||
@property # parents with self
|
||||
@@ -608,39 +605,13 @@ def get_location() -> Tuple[str, int]:
|
||||
@functools.lru_cache(None)
|
||||
def lines(fn) -> List[str]: return open(fn).readlines()
|
||||
|
||||
@dataclass(frozen=True, eq=False, repr=False) # reuse repr from UOp
|
||||
class NOp(UOp):
|
||||
name: Optional[str] = None
|
||||
# NOTE: this is fine because None dtype in NOp means any dtype is valid.
|
||||
dtype: Optional[DType] = None # type: ignore
|
||||
src: Tuple[NOp, ...] = tuple()
|
||||
allow_any_len: bool = False
|
||||
location: Tuple[str, int] = field(default_factory=get_location)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
|
||||
|
||||
# this is needed so NOp has a different cache
|
||||
@classmethod
|
||||
@functools.lru_cache(None)
|
||||
def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b)
|
||||
|
||||
@functools.cached_property
|
||||
def upat(self:NOp) -> UPat:
|
||||
return UPat(name=self.name, dtype=self.dtype, location=self.location) if self.op is UOps.NOOP else \
|
||||
UPat(self.op, self.arg, (list if self.commutative() else tuple)([src.upat for src in self.src]) or None, self.name,
|
||||
self.dtype, self.allow_any_len, location=self.location)
|
||||
|
||||
class UPat:
|
||||
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
|
||||
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False, location=None,
|
||||
class UPat(MathTrait):
|
||||
def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
|
||||
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None,
|
||||
custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
|
||||
self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
|
||||
self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
|
||||
self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
|
||||
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
|
||||
self.arg, self.name = arg, name
|
||||
self.in_src = src
|
||||
self.src: Any = None
|
||||
@@ -660,6 +631,31 @@ class UPat:
|
||||
upat_match = [self.in_src] if isinstance(self.in_src, UPat) else ([] if self.in_src is None else self.src[0])
|
||||
self.early_reject = set((pp.op[0], pp.arg) for pp in upat_match if pp.op is not None and len(pp.op) == 1)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(UOps.CONST, dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def const(dtype:Optional[DType], b:ConstType|Variable): return UPat(UOps.CONST, dtype=dtype, arg=b)
|
||||
|
||||
# copied from UOp
|
||||
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return type(self)(UOps.GEP, None, (self,), i)
|
||||
@classmethod
|
||||
def load(cls, *src:UPat, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore
|
||||
@classmethod
|
||||
def store(cls, *src:UPat): return cls(UOps.STORE, dtypes.void, src)
|
||||
|
||||
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
|
||||
def alu(self, arg, *src:UPat):
|
||||
asrc = (self,)+src
|
||||
return type(self)(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype,
|
||||
list(asrc) if arg in COMMUTATIVE else asrc, arg)
|
||||
|
||||
def printable(self:UPat) -> str:
|
||||
try:
|
||||
return lines(self.location[0])[self.location[1]-1].strip()
|
||||
@@ -687,8 +683,8 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
return res
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:Sequence[Tuple[Union[UPat, NOp], Callable]]):
|
||||
self.patterns = [(p.upat if isinstance(p, NOp) else p, fxn) for p,fxn in patterns]
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
self.patterns = patterns
|
||||
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list)
|
||||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
@@ -710,7 +706,7 @@ class PatternMatcher:
|
||||
TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 0)
|
||||
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
||||
class TrackedPattenMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
super().__init__(patterns)
|
||||
for p,_ in self.patterns:
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
|
||||
@@ -34,20 +34,21 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
shiftable_consts = set([2**i for i in range(64)])
|
||||
ptx_matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
src=[UPat(UOps.CONST, name="const"), UPat(name="mul")]),
|
||||
lambda root, mul, const: UOp(UOps.ALU, root.dtype,
|
||||
(mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
|
||||
(UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
src=[UPat(UOps.CONST, name="const"), UPat(name="div")]),
|
||||
lambda root, div, const: UOp(UOps.ALU, root.dtype,
|
||||
(div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.CMPNE, src=(UPat(dtype=dtypes.bool),UPat()), name="root"),
|
||||
lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.CMPLT, src=(UPat(name="x", dtype=dtypes.bool),UPat(name="y")), name="root"),
|
||||
lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x, UOp.const(dtypes.bool, True)), BinaryOps.CMPNE), y), BinaryOps.MUL)),
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(name="non_muls"), UPat(UOps.ALU, arg=BinaryOps.MUL, name="muls")], name="root"),
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
|
||||
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
|
||||
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
|
||||
for op in asm_for_op.keys() if op not in supports_half],
|
||||
(UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
|
||||
@@ -63,17 +64,17 @@ ptx_matcher = PatternMatcher([
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
|
||||
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
|
||||
# ptr_ar (load/store)
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
|
||||
UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
|
||||
lambda root, alu, const: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
const*root.src[0].dtype.itemsize)+root.src[2:])),
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
|
||||
UPat(UOps.CONST, name="const"))),
|
||||
lambda root, const: UOp(root.op, root.dtype,
|
||||
(root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
|
||||
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
|
||||
UPat(name="alu"))), # no const here
|
||||
lambda root, alu: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
|
||||
Reference in New Issue
Block a user