canonicalize the order prereqs (#7283)

* canonicalize the order

* don't change that yet

* that order isn't safe with uops
This commit is contained in:
George Hotz
2024-10-25 10:37:51 +07:00
committed by GitHub
parent 603d637105
commit bcf0537653
2 changed files with 11 additions and 12 deletions

View File

@@ -1,7 +1,7 @@
from typing import List, Set, Dict, Tuple, Any, Optional
from typing import List, Set, Dict, Tuple
import functools, heapq
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, UOps
from tinygrad.dtype import dtypes, DType
from tinygrad.dtype import dtypes
from tinygrad.helpers import DEBUG
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
@@ -57,15 +57,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
for x in u.src: fix_priority(x, priorities[u])
fix_priority(sink, 0)
@functools.lru_cache(None)
def tuplize(u:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
if u.op is UOps.ALU: arg = u.arg.value
else: arg = u.arg
return (u.op.value, arg, u.dtype, tuple(tuplize(x) for x in u.src))
# NOTE: the compare should never make it all the way to u
queue:List[Tuple[int, Tuple, UOp]] = []
def push(u:UOp): heapq.heappush(queue, (priorities[u], tuplize(u), u))
def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
for u in children:
if in_degree[u] == 0: push(u)

View File

@@ -101,6 +101,10 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
# the order of these UOps controls the order of the toposort
class UOps(FastEnum):
# consts!
VCONST = auto()
CONST = auto()
# uops that aren't rendered
SINK = auto()
CONTIGUOUS = auto()
@@ -119,8 +123,6 @@ class UOps(FastEnum):
DEFINE_VAR = auto()
DEFINE_LOCAL = auto()
DEFINE_ACC = auto()
VCONST = auto()
CONST = auto()
VALID = auto()
SPECIAL = auto()
NOOP = auto()
@@ -205,7 +207,6 @@ class UOp(MathTrait):
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
#if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}"
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
# NOTE: this has to be done early now because the COMMUTATIVE op can only be created one way
if op is UOps.ALU and arg in (BinaryOps.ADD, BinaryOps.MUL) and src[0].op is UOps.CONST and src[1].op is not UOps.CONST:
self.src = self.src[::-1]
def replace(self, **kwargs) -> UOp:
@@ -223,6 +224,10 @@ class UOp(MathTrait):
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@functools.cached_property
def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
return (self.op.value, self.arg.value if self.op is UOps.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src))
# *** uop shape stuff ***
@property