mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 08:57:24 +08:00
* basic blocks + cleanups * fixups * elif is better for future me * fuzz_schedule_max_paths * fix linter
487 lines
27 KiB
Python
487 lines
27 KiB
Python
from __future__ import annotations
|
|
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
|
|
import functools, itertools, heapq, math
|
|
from collections import defaultdict
|
|
from enum import Enum, auto
|
|
from dataclasses import dataclass, field
|
|
from tinygrad.dtype import ConstType, dtypes, DType
|
|
from tinygrad.shape.symbolic import sint, Variable
|
|
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
|
|
from tinygrad.helpers import prod, DEBUG, getenv
|
|
|
|
# the order of these UOps controls the order of the toposort
|
|
class UOps(Enum):
  """Kinds of micro-op a graph node can carry.

  Values are assigned by ``auto()`` in declaration order, and that order
  controls the order of the toposort, so members must not be rearranged.
  """
  # ops that aren't rendered
  SINK = auto()
  VAR = auto()
  DEFINE_GLOBAL = auto()
  DEFINE_VAR = auto()
  DEFINE_LOCAL = auto()
  DEFINE_ACC = auto()
  CONST = auto()
  SPECIAL = auto()
  NOOP = auto()
  UNMUL = auto()
  GEP = auto()
  # math ops
  CAST = auto()
  BITCAST = auto()
  VECTORIZE = auto()
  ALU = auto()
  WMMA = auto()
  # memory/assignment ops
  LOAD = auto()
  STORE = auto()
  PHI = auto()
  # control flow ops
  BARRIER = auto()
  IF = auto()
  RANGE = auto()
  # these two are not graph nodes
  ENDRANGE = auto()
  ENDIF = auto()
|
|
|
|
# scope-opening op -> (op that bounds the search for its scope children, op emitted to close the scope)
END_FOR_UOP = {
  UOps.IF: (UOps.STORE, UOps.ENDIF),
  UOps.RANGE: (UOps.PHI, UOps.ENDRANGE),
}
|
|
|
|
def ufix(dtype: Optional[DType], x):
  """Return `x` unchanged if it is already a UOp, otherwise wrap it as `UOp.const(dtype, x)`."""
  if isinstance(x, UOp):
    return x
  return UOp.const(dtype, x)
|
|
@dataclass(frozen=True, eq=False)
class UOp:
  """One node of the uop graph: an op, an optional dtype, its source nodes and an op-specific arg.

  `eq=False` means the dataclass generates no `__eq__`, so nodes compare and hash with the object defaults.
  """
  op: UOps
  dtype: Optional[DType] = None
  src: Tuple[UOp, ...] = tuple()
  arg: Any = None

  def commutative(self) -> bool:
    """True for an ALU node whose arg is ADD, MUL, MAX, CMPNE or XOR."""
    if self.op is not UOps.ALU:
      return False
    return self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}

  @functools.cached_property
  def cmp_tuple(self):
    """Sort key used by `__lt__`: (op value, arg key, dtype, src)."""
    # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
    if self.op is UOps.ALU:
      arg_key = self.arg.value
    elif self.op is UOps.DEFINE_VAR:
      arg_key = self.arg.expr
    else:
      arg_key = self.arg
    return (self.op.value, arg_key, self.dtype, self.src)

  def __lt__(self, x:UOp):
    return self.cmp_tuple < x.cmp_tuple

  def __repr__(self):
    dtype_str = str(self.dtype) if self.dtype is not None else ''
    src_ops = str([x.op for x in self.src])
    return f"{str(self.op):20s}: {dtype_str:25s} {src_ops:32s} {self.arg}"

  def cast(self, dtype=None):
    """Wrap this node in a CAST to `dtype`."""
    return UOp(UOps.CAST, dtype, (self,))

  def name(self, name:Optional[str]):
    """Wrap this node in a VAR node carrying `name` as its arg."""
    return UOp(UOps.VAR, src=(self,), arg=name)

  # *** operator sugar: python-value operands are turned into consts of this node's dtype via ufix ***

  def __neg__(self):
    return UOp.alu(UnaryOps.NEG, self)

  def __add__(self, x):
    return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))

  def __radd__(self, x):
    return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)

  def __sub__(self, x):
    # subtraction is expressed as adding the negation
    return UOp.alu(BinaryOps.ADD, self, -ufix(self.dtype, x))

  def __mul__(self, x):
    return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))

  def __rmul__(self, x):
    return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)

  def __floordiv__(self, x):
    return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))

  def __truediv__(self, x):
    # true division is expressed as multiplying by the reciprocal
    return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))

  def __mod__(self, x):
    return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))

  def ne(self, x):
    return UOp.alu(BinaryOps.CMPNE, self, ufix(self.dtype, x))

  def eq(self, x):
    # built as the NEG of the CMPNE node
    return -self.ne(x)

  def lt(self, x):
    return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))

  def ge(self, x):
    # built as the NEG of the CMPLT node
    return -self.lt(x)

  def max(self, x):
    return UOp.alu(BinaryOps.MAX, self, x)

  def min(self, x):
    # min(a, b) is built as -max(-a, -b)
    return -UOp.alu(BinaryOps.MAX, -self, -x)

  @staticmethod
  @functools.lru_cache(maxsize=None)
  def const(dtype:Optional[DType], b:ConstType|Variable):
    """Build (and cache) a DEFINE_VAR node for a Variable, or a CONST node for anything else."""
    if isinstance(b, Variable):
      return UOp(UOps.DEFINE_VAR, dtype, (), b)
    # with a dtype the value is normalized through dtypes.as_const; without one it is stored as given
    value = dtypes.as_const(b, dtype) if dtype is not None else b
    return UOp(UOps.CONST, dtype, arg=value)

  @staticmethod
  def alu(arg, *src:UOp):
    """Build an ALU node; CMPLT/CMPNE produce bool, everything else takes the dtype of the last source."""
    out_dtype = dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype
    return UOp(UOps.ALU, out_dtype, src, arg)

  @staticmethod
  def load(*src:UOp, dtype:Optional[DType]=None, **kwargs):
    """Build a LOAD node whose sources are `src` followed by the keyword-argument values."""
    return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))

  @staticmethod
  def store(*src:UOp, **kwargs):
    """Build a STORE node (dtype None) whose sources are `src` followed by the keyword-argument values."""
    return UOp(UOps.STORE, None, tuple(src)+tuple(kwargs.values()))

  @staticmethod
  def var(name:Optional[str]=None, dtype:Optional[DType]=None):
    """Build a source-less VAR node carrying `name` as its arg."""
    return UOp(UOps.VAR, dtype=dtype, arg=name)

  @staticmethod
  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None):
    """Build a VAR node named `name` that wraps an arg-less CONST node of `dtype`."""
    return UOp(UOps.CONST, dtype=dtype).name(name)

  @functools.cached_property
  def parents(self) -> Set[UOp]:
    """All transitive sources of this node (the node itself is not included)."""
    found = set(self.src)
    for x in self.src:
      found |= x.parents
    return found

  @property # parents with self
  def sparents(self) -> Set[UOp]:
    return {self} | self.parents

  def vars(self) -> Set[UOp]:
    """The DEFINE_VAR nodes among this node and its parents."""
    return {x for x in self.sparents if x.op is UOps.DEFINE_VAR}

  def divides(self, v):
    """Conservatively report whether this expression is a multiple of `v`.

    True for a CONST whose arg is divisible by `v`, an ADD whose sources all divide,
    or a MUL where any source divides; False in every other case.
    """
    if self.op is UOps.CONST:
      return self.arg%v == 0
    if self.op is UOps.ALU:
      if self.arg is BinaryOps.ADD:
        return all(x.divides(v) for x in self.src)
      if self.arg is BinaryOps.MUL:
        return any(x.divides(v) for x in self.src)
    return False # generic false if we aren't sure
|
|
|
|
def type_verify(uops):
  """Assert that every uop in `uops` has an op, dtype, arg and sources that agree with each other.

  Each check is an `assert`, so the first inconsistent uop raises AssertionError.
  """
  for u in uops:
    uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype

    # CONST / DEFINE_ACC: the constant value must have the python type that the dtype implies
    if uop in {UOps.CONST, UOps.DEFINE_ACC}:
      if uop is UOps.DEFINE_ACC:
        assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
        # for an accumulator, the value that gets checked is the arg of its first source
        arg = src[0].arg
      assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"

    # conversions and vectorization
    if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
      assert arg is None and dtype is not None # type is the output type, not an arg
    if uop is UOps.CAST:
      assert dtype.count == 1 and len(src) == dtype.count
    if uop is UOps.VECTORIZE:
      assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
      assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"

    # memory ops (src[2]/src[3] of such a LOAD are presumably the gate and the fallback value)
    if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU:
      assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
    if uop is UOps.STORE:
      assert dtype is None, f"{uop} dtype must be None, got {dtype}"
      if len(src) == 4:
        assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"

    # ALU ops: the rule depends on which operation the arg names
    if uop is UOps.ALU:
      if arg in UnaryOps:
        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
      elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
        assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
        assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
      elif arg is BinaryOps.IDIV:
        assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
          f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
        assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
      elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
        # the distance to shift isn't typechecked
        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
      elif arg in BinaryOps:
        assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
      elif arg == TernaryOps.WHERE:
        assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
        assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
|
|
|
def uop_alu_resolve(u:UOp) -> sint:
  """Evaluate a uop expression to a value.

  CONST and DEFINE_VAR yield their arg, SPECIAL yields `arg[2]-1`, and MUL / SHL / ADD
  ALU nodes are evaluated recursively from their two sources.

  Raises:
    RuntimeError: for any other kind of node.
  """
  if u.op is UOps.CONST or u.op is UOps.DEFINE_VAR:
    return u.arg
  if u.op is UOps.SPECIAL:
    return u.arg[2]-1
  if u.op is UOps.ALU:
    if u.arg is BinaryOps.MUL:
      lhs = uop_alu_resolve(u.src[0])
      return lhs * uop_alu_resolve(u.src[1])
    if u.arg is BinaryOps.SHL:
      # a left shift by k is a multiplication by 2**k
      lhs = uop_alu_resolve(u.src[0])
      return lhs * (2**cast(int, uop_alu_resolve(u.src[1])))
    if u.arg is BinaryOps.ADD:
      lhs = uop_alu_resolve(u.src[0])
      return lhs + uop_alu_resolve(u.src[1])
  raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
|
|
|
# *** simplification logic ***
|
|
|
|
@dataclass(frozen=True)
class UPat:
  """A pattern that `_match` tests against a UOp; a field left as None places no constraint."""
  # a single op, or a set of acceptable ops
  op: Optional[Union[UOps, Set[UOps]]] = None
  arg: Any = None
  # tuple: sources matched in order; list: any permutation may match; single UPat: applied to every source
  src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
  # key under which the matched UOp is stored for the rewrite function
  name: Optional[str] = None
  # a single dtype, or a set of acceptable dtypes
  dtype: Optional[Union[DType, Set[DType]]] = None
  # extra source counts accepted besides the pattern's own length
  allow_len: Set[int] = field(default_factory=set)
  # when True the source-count check is skipped
  allow_any_len: bool = False

  @staticmethod
  def compile(u: UOp, name:Optional[str]=None) -> UPat:
    """Translate a template UOp (built from VAR placeholders) into the equivalent UPat."""
    if u.op is UOps.VAR:
      label = name or u.arg
      if len(u.src) == 0:
        # a bare VAR matches anything, constrained only by its dtype
        return UPat(name=label, dtype=u.dtype)
      # a VAR wrapping a node only names that node
      return UPat.compile(u.src[0], label)
    if u.src != ():
      # commutative nodes get a list so that _match tries every ordering of the sources
      container = list if u.commutative() else tuple
      src_pats = container([UPat.compile(src) for src in u.src])
    else:
      src_pats = None
    any_len = isinstance(name, str) and 'allow_any_len' in name
    return UPat(u.op, u.arg, src_pats, name, u.dtype, allow_any_len=any_len)
|
|
|
|
T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
  """Return True when `m2` does NOT satisfy `m1`: not a member if `m1` is a set, otherwise not equal."""
  if isinstance(m1, set):
    return m2 not in m1
  return m2 != m1
|
|
|
|
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
  """Test `uop` against `pat`, recording named sub-matches in `store`.

  A name that is already bound must be bound to this very UOp (identity). On success
  `store` holds every named node of the match.
  """
  if pat.name is not None and store.setdefault(pat.name, uop) is not uop:
    return False
  if pat.arg is not None and __unmatch(pat.arg, uop.arg):
    return False
  # a uop without a dtype is never rejected on dtype
  if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype):
    return False
  if pat.op is not None and __unmatch(pat.op, uop.op):
    return False
  if pat.src is None:
    return True

  if isinstance(pat.src, list):
    # try all permutations if it's a list
    candidates = itertools.permutations(pat.src)
  elif isinstance(pat.src, tuple):
    # only one if it's a tuple
    candidates = [pat.src]
  else:
    # repeat if it's a UPat
    candidates = [(pat.src,)*len(uop.src)]

  n_src = len(uop.src)
  for vp in candidates:
    length_ok = n_src == len(vp) or n_src in pat.allow_len or pat.allow_any_len
    if not length_ok:
      return False
    # match into a copy so a failed ordering leaves no bindings behind
    new_store = store.copy()
    if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
      store.update(new_store)
      return True
  return False
|
|
|
|
class PatternMatcher:
  """A list of (pattern, rewrite function) rules, indexed by (op, arg) for fast lookup.

  Attributes:
    patterns: the rules exactly as passed in (UOp templates are not yet compiled here).
    pdict: (op, arg) -> rules for that key, in the order they were given; rules whose
      pattern has no arg are stored under (op, None).
  """
  def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
    self.patterns = patterns
    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      if isinstance(p, UOp): p = UPat.compile(p)
      assert p.op is not None
      if isinstance(p.op, set):
        # a pattern accepting several ops is registered once under each of them
        for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
      else:
        self.pdict[(p.op, p.arg)].append((p, fxn))

  def rewrite(self, uop:UOp) -> Optional[UOp]:
    """Return the result of the first rule that matches `uop` and yields non-None, else None.

    Rules registered for the exact (op, arg) are tried first, then the arg-less rules for the op.
    """
    # .get() rather than indexing: indexing a defaultdict would insert an empty list for every
    # (op, arg) that is ever looked up, growing the table without bound
    exact = self.pdict.get((uop.op, uop.arg), ())
    # when the uop has no arg the two buckets are the same one; don't run every rule twice
    wildcard = self.pdict.get((uop.op, None), ()) if uop.arg is not None else ()
    for p,fxn in itertools.chain(exact, wildcard):
      store: Dict[str, UOp] = {}
      if _match(uop, p, store) and (ret:=fxn(**store)) is not None: return ret # NOTE: if it returns None, we keep trying to match
    return None
|
|
|
|
def sum_collapse(phi_input, loop, val1, val2):
  """Pull an addend that does not have `loop` among its parents out of an accumulating PHI.

  The first of (val1, val2) that qualifies is multiplied by `loop.src[1]-loop.src[0]` and added
  to a new PHI over the other addend. Returns None when both addends have `loop` as a parent.
  """
  for independent, remaining in ((val1, val2), (val2, val1)):
    if loop in independent.parents:
      continue
    trip_count = loop.src[1] - loop.src[0]
    hoisted = independent * trip_count.cast(independent.dtype)
    return UOp(UOps.PHI, phi_input.dtype, (phi_input, remaining)) + hoisted
  return None
|
|
|
|
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
  """Rewrite function for the "arange loop folding" rules: replace the matched WHERE with an UNMUL node.

  Returns None (no rewrite) when DISABLE_LOOP_COLLAPSE is set, when `rng.arg[2]` is falsy,
  or when `mval` is not negative or the loop does not start at 0.
  """
  if getenv("DISABLE_LOOP_COLLAPSE"):
    return None
  if not rng.arg[2]:
    return None # must be a reduce
  if mval.arg >= 0 or loop_start.arg != 0:
    # TODO: support and test this with other mvals and loop_starts
    if DEBUG >= 1:
      print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
    return None
  numerator = idx - compval - mval
  quotient = UOp.alu(BinaryOps.IDIV, numerator, mval)
  # clamp (quotient + loop length) into [loop_start, loop_end]
  lower_clamped = UOp.max(quotient + (loop_end - loop_start), loop_start)
  comprange = UOp.min(loop_end, lower_clamped)
  scaled = comprange.cast(multconst.dtype) * multconst
  return UOp(UOps.UNMUL, multconst.dtype, (scaled, loop_end - loop_start))
|
|
|
|
# this is symbolic 2.0
|
|
# Rewrite rules as (pattern, function) pairs. A pattern is either a UPat or a template UOp built
# from UOp.var / UOp.cvar placeholders. The function receives the named matches as keyword
# arguments and returns the replacement node, or None to decline.
# Rules that share an (op, arg) key are tried in list order, so the order of entries matters.
constant_folder = PatternMatcher([
  # arange loop folding (early)
  (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(UOps.CONST, name="mval"),
      UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
      UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
  # same shape with NEG(range) in place of mval*range; mval is supplied as the constant -1
  (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, UnaryOps.NEG, src=[
      UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
      UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))),
      lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
  # sum collapse to mul (with possible GEP)
  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
                       UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
  (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
                       UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
  # deal with UNMUL
  (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
   lambda c1,c2,v: v if c1.arg == c2.arg else None),
  (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
  (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
  # max on special can go away (TODO: special should be variable, same thing applies)
  (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
  # const rules
  (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
  (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
  (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
  (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
  (UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
  (UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
  # NOTE(review): this GEP rule repeats the one under "const rules" above (same pattern, same result);
  # the earlier one always returns a node, so this entry is never reached
  (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
  # max(x, -2147483648) -> x
  (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
  # bool < False is always false, True < bool is always false
  (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
  (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
  # a conditional with the same results either way is a noop, also fold const conditionals
  (UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
  (UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
  # ** constant folding **
  (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
  # ** self folding **
  (-(-UOp.var('x')), lambda x: x),    # -(-x) -> x
  (UOp.var('x') + 0, lambda x: x),    # x+0 -> x
  (UOp.var('x') - 0, lambda x: x),    # x-0 -> x
  (UOp.var('x') * 1, lambda x: x),    # x*1 -> x
  (UOp.var('x') * -1, lambda x: -x),  # x*-1 -> -x
  (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
  (UOp.var('x') // 1, lambda x: x),   # x//1 -> x
  (UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
  (UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
  (UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])),    # x/c -> x*(1/c)
  (UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x),  # max(x, False) -> x
  # ** zero folding **
  # x*0 -> 0 or 0*x -> 0
  # if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)),   # x-x -> 0
  # ** load/store folding **
  # storing back the value that was just loaded from the same buf/idx is a NOOP
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
  # ** two stage add/sub folding **
  ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
  ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
  # *** rules from symbolic ***
  # two stage mul, (x*c1)*c2 = x*(c1*c2)
  ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
  # x%1 -> 0
  (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
  # (x*c0)+(x*c1) -> x*(c0+c1)
  (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
  # (x*c0)+(y*c0) -> (x+y)*c0
  #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
  # (x*c0)//c0 -> x
  ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
  # (x*x2)/x2 -> x
  ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
  # (x//c0)//c1 -> x//(c0*c1)
  ((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
  # (x/x2)/x3 -> x/(x2*x3)
  ((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
  # c0 + x < c1 -> x < c1 - c0
  ((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
    lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
  # (x+x*c0)-> x*(c0+1)
  (UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
  # x!=0 -> (bool)x
  (UOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
  # bool != 1 -> not bool
  (UOp.var("x", dtype=dtypes.bool).ne(1), lambda x: -x),
  # a store of WHERE(gate, alt, load(buf, idx)) back to buf/idx becomes a store of alt with gate as fourth source
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
   lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
  # store float4/float2 directly (remove VECTORIZE/GEP)
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
    UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.VECTORIZE, src=tuple(
    UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
  # VECTORIZE-PHI-GEP -> PHI-VECTORIZE
  (UPat(UOps.VECTORIZE, name="root", src=tuple(
    UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
    lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
  (UPat(UOps.VECTORIZE, name="root", src=tuple(
    UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
    lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1))))),
  # NEG/CMPLT -> CMPLT
  (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
  # cast NOOP (NOTE: it's str to deal with PtrDType)
  (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
  (UPat(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
  # fold gated LOAD/STORE
  # a constant-True gate is dropped; a constant-False gate turns the load into its constant fourth source
  # and the store into a NOOP
  (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
  (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.bool, True), UOp.cvar("var"), UOp.var("barrier")),
   lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
  (UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var")), lambda var: var),
  (UOp.load(UOp.var(), UOp.var(), UOp.const(dtypes.bool, False), UOp.cvar("var"), UOp.var()), lambda var: var),
  (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.bool, True)), UOp.store),
  (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
  # remove NOOPs from SINK
  (UPat(UOps.SINK, name="root"),
    lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
])
|
|
|
|
# *** uop graph ***
|
|
|
|
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
  """Walk the graph from `u` through `src`, filling `children` and `in_degree` in place.

  After the call, `children[x]` lists every visited node that has `x` as a source (once per
  occurrence) and `in_degree[x]` is `len(x.src)`. Nodes already in `children` are not revisited.
  """
  already_visited = u in children
  if already_visited:
    return
  children[u] = []
  for parent in u.src:
    # visit the source first so that its children list exists before we append to it
    get_children_dfs(parent, children, in_degree)
    children[parent].append(u)
  in_degree[u] = len(u.src)
|
|
|
|
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
  """Rebuild the graph under `sink` bottom-up, deduplicating nodes and applying `pm` until no rule fires.

  Returns the node that replaces `sink`.
  """
  # (op, dtype, rewritten src, arg) -> the node that represents that content
  nodes: Dict[Tuple, UOp] = {}
  # original node -> its replacement
  replace: Dict[UOp, UOp] = {}

  def __inner_rewrite(n:UOp) -> UOp:
    if n in replace:
      return replace[n]
    new_src = tuple(__inner_rewrite(y) for y in n.src)
    key = (n.op, n.dtype, new_src, n.arg)
    found = nodes.get(key)
    if found:
      # a node with the same content was already produced, reuse it
      replace[n] = found
    else:
      x = UOp(*key)
      new_x = pm.rewrite(x)
      # a rewritten node is itself rewritten again until the matcher has nothing more to do
      result = __inner_rewrite(new_x) if new_x else x
      nodes[key] = result
      replace[n] = result
    return replace[n]

  return __inner_rewrite(sink)
|
|
|
|
class UOpGraph:
  """A set of sink UOps together with their lazily computed linear order."""
  def __init__(self, sinks:List[UOp]):
    # NOTE: the list is kept by reference and linearize() may replace entries in it
    self.sinks: List[UOp] = sinks
    # used by linearizer
    self._uops: Optional[List[UOp]] = None

  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
  def __getitem__(self, index) -> UOp: return self.uops[index]

  def vars(self) -> List[Variable]:
    """Args of all DEFINE_VAR uops, sorted by their `expr`."""
    return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)

  def globals(self) -> List[Tuple[int, bool]]:
    """Args of all DEFINE_GLOBAL uops, in linearized order."""
    return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]

  @property
  def uops(self) -> List[UOp]:
    """The linearized uop list; linearize() is run on first access."""
    if self._uops is None: self.linearize()
    # the type is given as a string so that no generic alias is built on every access
    return cast("List[UOp]", self._uops)

  def graph(self):
    from tinygrad.engine.graph import graph_uops
    graph_uops(self.uops)

  def print(self):
    """Print one line per uop: its index, op, dtype, the indices of its sources, and its arg."""
    uops = self.uops
    # first position of every uop, computed once instead of a list.index scan per source
    position: Dict[UOp, int] = {}
    for i,u in enumerate(uops): position.setdefault(u, i)
    for i,u in enumerate(uops):
      print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([position[x] for x in u.src]):32s} {u.arg}")

  def linearize(self, extra_pm:Optional[PatternMatcher]=None, do_type_verify=True):
    """Rewrite the graph and flatten it into `self._uops`.

    Steps: turn barrier loads under gated stores into IF-guarded loads, join the sinks under a
    SINK node, run graph_rewrite with constant_folder (and once more with the `extra_pm` rules
    appended, if given), toposort with a priority heap, insert ENDRANGE/ENDIF, and drop the SINK.

    Args:
      extra_pm: additional rules to apply after the first rewrite pass.
      do_type_verify: when true, run type_verify on the result.
    """
    # NOTE: relinearizering should be okay
    #assert self._uops is None, "already linearized"

    # fixup gated stores with an IF block to save extra local loads
    @functools.lru_cache(None)
    def _dfs(u:UOp, gate:UOp) -> UOp:
      if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
        if_uop = UOp(UOps.IF, None, (gate, u.src[-1]))
        return UOp(u.op, u.dtype, u.src[:-1]+(if_uop,), u.arg)
      if (replace_source:=tuple(_dfs(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg)
      return u
    for i, s in enumerate(self.sinks[:]):
      # a store that was rewritten loses its gate (fourth source); the gate now lives in the IF
      if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s: self.sinks[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg)
    sink = UOp(UOps.SINK, None, tuple(self.sinks))

    # dedup all nodes and do graph rewrite
    sink = graph_rewrite(sink, constant_folder)
    if extra_pm: sink = graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))

    # filter nodes that don't link to a sink
    # BFS toposort
    children: Dict[UOp, List[UOp]] = {}
    in_degree: Dict[UOp, int] = {}
    get_children_dfs(sink, children, in_degree)

    @functools.lru_cache(None)
    def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
      if x.op is UOps.SINK: return set()
      # the walk does not continue past a node whose op is `end` (that node itself is still included)
      return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))

    # scope children impact the toposort and END* insertion
    scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}

    queue:List[Tuple[int, UOp]] = []
    def push(u:UOp):
      priority = 0
      # prefer uops that are loop children
      for l, ss in scope_children.items():
        if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
      # entries are (priority, uop) tuples, so equal priorities are ordered by comparing the uops
      heapq.heappush(queue, (priority, u))

    for u in children:
      if in_degree[u] == 0: push(u)

    self._uops = []
    while queue:
      p,x = heapq.heappop(queue)
      if DEBUG >= 7: print(p,x)
      if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
        # an accumulator with RANGE sources is placed in front of the earliest of those ranges
        idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
        self._uops.insert(idx, x)
      else: self._uops.append(x)
      for u in children[x]:
        in_degree[u] -= 1
        if in_degree[u] == 0: push(u)

    # close every scope right after the last of its scope children.
    # iterate over a snapshot: the loop body inserts into self._uops, and mutating a list
    # while iterating over it is only correct as long as every insert lands after the cursor
    for u in list(self._uops):
      if u.op in END_FOR_UOP: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))

    assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
    self._uops = self._uops[:-1]

    if getenv("FUZZ_UOPS"):
      from test.external.fuzz_uops import fuzz_uops
      self._fuzz_paths = fuzz_uops(self)
    if do_type_verify: type_verify(self.uops)

  # *** checker functions ***

  def flops_mem(self, ignore_indexing=False) -> Tuple[sint, sint]:
    """Estimate (flops, bytes moved) of the linearized program.

    Every op counts once per iteration of the RANGEs enclosing it. With `ignore_indexing`,
    ALU/WMMA uops that feed a LOAD/STORE index (src[1]) or gate are not counted as flops.
    """
    flops: sint = 0
    mem: sint = 0
    # product of the sizes of the ranges currently open, and the saved values for the outer scopes
    mults: sint = 1
    mult_stack = []
    dont_count: Set[UOp] = set()
    if ignore_indexing:
      for u in self.uops:
        if u.op is UOps.LOAD:
          dont_count = dont_count.union(u.src[1].sparents)
          if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
        elif u.op is UOps.STORE:
          dont_count = dont_count.union(u.src[1].sparents)
          if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
    for u in self.uops:
      if u.op is UOps.RANGE:
        mult_stack.append(mults)
        mults *= uop_alu_resolve(u.src[1])
      elif u.op is UOps.ENDRANGE:
        mults = mult_stack.pop(-1)
      elif u.op is UOps.LOAD:
        assert u.dtype is not None
        mem += u.dtype.itemsize * mults
      elif u.op is UOps.STORE:
        assert u.src[2].dtype is not None
        mem += u.src[2].dtype.itemsize * mults
      elif u.op is UOps.ALU and u not in dont_count:
        # a MULACC counts as two flops, every other ALU op as one
        flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
      elif u.op is UOps.WMMA and u not in dont_count:
        assert u.arg[1] is not None
        flops += 2 * prod(u.arg[1]) // 32 * mults
    return flops, mem
|