mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
canonicalize the order prereqs (#7283)
* canonicalize the order * don't change that yet * that order isn't safe with uops
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import List, Set, Dict, Tuple, Any, Optional
|
||||
from typing import List, Set, Dict, Tuple
|
||||
import functools, heapq
|
||||
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, UOps
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import DEBUG
|
||||
|
||||
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
|
||||
@@ -57,15 +57,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
for x in u.src: fix_priority(x, priorities[u])
|
||||
fix_priority(sink, 0)
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def tuplize(u:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
|
||||
if u.op is UOps.ALU: arg = u.arg.value
|
||||
else: arg = u.arg
|
||||
return (u.op.value, arg, u.dtype, tuple(tuplize(x) for x in u.src))
|
||||
|
||||
# NOTE: the compare should never make it all the way to u
|
||||
queue:List[Tuple[int, Tuple, UOp]] = []
|
||||
def push(u:UOp): heapq.heappush(queue, (priorities[u], tuplize(u), u))
|
||||
def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
|
||||
|
||||
for u in children:
|
||||
if in_degree[u] == 0: push(u)
|
||||
|
||||
@@ -101,6 +101,10 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
|
||||
|
||||
# the order of these UOps controls the order of the toposort
|
||||
class UOps(FastEnum):
|
||||
# consts!
|
||||
VCONST = auto()
|
||||
CONST = auto()
|
||||
|
||||
# uops that aren't rendered
|
||||
SINK = auto()
|
||||
CONTIGUOUS = auto()
|
||||
@@ -119,8 +123,6 @@ class UOps(FastEnum):
|
||||
DEFINE_VAR = auto()
|
||||
DEFINE_LOCAL = auto()
|
||||
DEFINE_ACC = auto()
|
||||
VCONST = auto()
|
||||
CONST = auto()
|
||||
VALID = auto()
|
||||
SPECIAL = auto()
|
||||
NOOP = auto()
|
||||
@@ -205,7 +207,6 @@ class UOp(MathTrait):
|
||||
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
|
||||
#if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}"
|
||||
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
|
||||
# NOTE: this has to be done early now because the COMMUTATIVE op can only be created one way
|
||||
if op is UOps.ALU and arg in (BinaryOps.ADD, BinaryOps.MUL) and src[0].op is UOps.CONST and src[1].op is not UOps.CONST:
|
||||
self.src = self.src[::-1]
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
@@ -223,6 +224,10 @@ class UOp(MathTrait):
|
||||
@property # parents with self
|
||||
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
||||
|
||||
@functools.cached_property
|
||||
def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
|
||||
return (self.op.value, self.arg.value if self.op is UOps.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src))
|
||||
|
||||
# *** uop shape stuff ***
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user