From bcf05376530eb7f10f3d8b9cda6ff4e4a8249cd7 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 25 Oct 2024 10:37:51 +0700 Subject: [PATCH] canonicalize the order prereqs (#7283) * canonicalize the order * don't change that yet * that order isn't safe with uops --- tinygrad/codegen/linearize.py | 12 +++--------- tinygrad/ops.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index 744f0d0da0..1f7d3fa32a 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -1,7 +1,7 @@ -from typing import List, Set, Dict, Tuple, Any, Optional +from typing import List, Set, Dict, Tuple import functools, heapq from tinygrad.ops import type_verify, END_FOR_UOP, UOp, UOps -from tinygrad.dtype import dtypes, DType +from tinygrad.dtype import dtypes from tinygrad.helpers import DEBUG def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]): @@ -57,15 +57,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: for x in u.src: fix_priority(x, priorities[u]) fix_priority(sink, 0) - @functools.lru_cache(None) - def tuplize(u:UOp) -> Tuple[int, Any, Optional[DType], Tuple]: - if u.op is UOps.ALU: arg = u.arg.value - else: arg = u.arg - return (u.op.value, arg, u.dtype, tuple(tuplize(x) for x in u.src)) - # NOTE: the compare should never make it all the way to u queue:List[Tuple[int, Tuple, UOp]] = [] - def push(u:UOp): heapq.heappush(queue, (priorities[u], tuplize(u), u)) + def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u)) for u in children: if in_degree[u] == 0: push(u) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index eaad66c9f1..e73b17daf3 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -101,6 +101,10 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps. # the order of these UOps controls the order of the toposort class UOps(FastEnum): + # consts! + VCONST = auto() + CONST = auto() + # uops that aren't rendered SINK = auto() CONTIGUOUS = auto() @@ -119,8 +123,6 @@ class UOps(FastEnum): DEFINE_VAR = auto() DEFINE_LOCAL = auto() DEFINE_ACC = auto() - VCONST = auto() - CONST = auto() VALID = auto() SPECIAL = auto() NOOP = auto() @@ -205,7 +207,6 @@ class UOp(MathTrait): #if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src]) #if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}" self.op, self.dtype, self.src, self.arg = op, dtype, src, arg - # NOTE: this has to be done early now because the COMMUTATIVE op can only be created one way if op is UOps.ALU and arg in (BinaryOps.ADD, BinaryOps.MUL) and src[0].op is UOps.CONST and src[1].op is not UOps.CONST: self.src = self.src[::-1] def replace(self, **kwargs) -> UOp: @@ -223,6 +224,10 @@ class UOp(MathTrait): @property # parents with self def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None} + @functools.cached_property + def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]: + return (self.op.value, self.arg.value if self.op is UOps.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src)) + # *** uop shape stuff *** @property