mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
added visitor pattern (#1669)
* added visitor pattern * pylint bug workaround * added tests, made abstract OpNode inherit from ABC * fixed assert * fix check of abstract classes in negative test * remove assert False
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_render
|
||||
from tinygrad.shape.symbolic import Node, MulNode, SumNode, Variable, NumNode, LtNode, sym_render
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
@@ -366,6 +366,62 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
b = NumNode(0) * a
|
||||
assert b == 0
|
||||
assert isinstance(b, NumNode)
|
||||
|
||||
def test_num_node_expand(self):
|
||||
a = NumNode(42)
|
||||
assert a.expand() == [a]
|
||||
|
||||
def test_variable_expand(self):
|
||||
a = Variable("a", 5, 7)
|
||||
assert a.expand() == [a]
|
||||
|
||||
def test_variable_expand_expr_none(self):
|
||||
a = Variable(None, 5, 7)
|
||||
assert a.expand() == [NumNode(5), NumNode(6), NumNode(7)]
|
||||
|
||||
def test_mul_node_expand(self):
|
||||
a = Variable(None, 5, 7)
|
||||
m = MulNode(a, 3)
|
||||
assert m.expand() == [NumNode(15), NumNode(18), NumNode(21)]
|
||||
|
||||
b = Variable("b", 1, 3)
|
||||
n = MulNode(b, 3)
|
||||
assert n.expand() == [Variable("b", 1, 3)*3]
|
||||
|
||||
def test_sum_node_expand(self):
|
||||
a = Variable(None, 1, 3)
|
||||
b = Variable("b", 5, 7)
|
||||
|
||||
s1 = SumNode([a, b])
|
||||
assert s1.expand() == [Variable.sum([NumNode(i),b]) for i in range(1,4)]
|
||||
|
||||
c = Variable(None, 5, 7)
|
||||
|
||||
s2 = SumNode([a, c])
|
||||
assert s2.expand() == [Variable.sum([NumNode(i),NumNode(j)]) for (i,j) in [(1,5), (1,6), (1,7), (2,5), (2,6), (2,7), (3,5), (3,6), (3,7)]]
|
||||
|
||||
def test_non_expandable_nodes(self):
|
||||
expandable_nodes = [Variable, NumNode, MulNode, SumNode]
|
||||
|
||||
def test_non_expandable_nodes_recursive(node_cls: Node):
|
||||
for node_subcls in node_cls.__subclasses__():
|
||||
# skip expandable classes
|
||||
if node_subcls in expandable_nodes:
|
||||
continue
|
||||
|
||||
# recurse over subclasses
|
||||
test_non_expandable_nodes_recursive(node_subcls)
|
||||
|
||||
# skip classes with abstract methods
|
||||
if len(node_subcls.__abstractmethods__) > 0:
|
||||
continue
|
||||
|
||||
# test that node expand is not implemented
|
||||
node = node_subcls.__new__(node_subcls)
|
||||
self.assertRaises(NotImplementedError, node.expand)
|
||||
|
||||
test_non_expandable_nodes_recursive(Node)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -99,16 +99,8 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
|
||||
return zip(new_idxs, new_values)
|
||||
return zip([[i] for i in range(len(values[0]))], zip(*values))
|
||||
|
||||
# TODO: generic visitor pattern?
|
||||
def expand_node(idx:Node) -> List[Node]:
|
||||
if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)]
|
||||
if isinstance(idx, NumNode): return [idx]
|
||||
if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
|
||||
if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
|
||||
raise NotImplementedError(idx)
|
||||
|
||||
def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]:
|
||||
for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]):
|
||||
for x in itertools.product(*[idx.expand() for idx in idxs[::-1]]):
|
||||
yield x[::-1]
|
||||
|
||||
class MemOp(NamedTuple):
|
||||
@@ -144,7 +136,7 @@ class Linearizer(OptimizedKernel):
|
||||
def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[Token]:
|
||||
const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc
|
||||
|
||||
expanded_nodes = [expand_node(idx) for idx in idxs]
|
||||
expanded_nodes = [idx.expand() for idx in idxs]
|
||||
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
|
||||
upcast_dim = self.get_upcast_dim(i)
|
||||
|
||||
@@ -174,7 +166,7 @@ class Linearizer(OptimizedKernel):
|
||||
return ret
|
||||
|
||||
def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> None:
|
||||
expanded_nodes = [expand_node(idx) for idx in idxs]
|
||||
expanded_nodes = [idx.expand() for idx in idxs]
|
||||
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
|
||||
upcast_dim = self.get_upcast_dim(i)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
from abc import abstractmethod, ABC
|
||||
import functools
|
||||
import itertools
|
||||
from math import gcd
|
||||
from tinygrad.helpers import partition
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
|
||||
@@ -10,7 +11,7 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
|
||||
|
||||
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
|
||||
|
||||
class Node:
|
||||
class Node(ABC):
|
||||
b: Union[Node, int]
|
||||
min: int
|
||||
max: int
|
||||
@@ -21,6 +22,8 @@ class Node:
|
||||
if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
|
||||
return ret
|
||||
def vars(self): return []
|
||||
def expand(self) -> List[Node]:
|
||||
raise NotImplementedError(self.__class__.__name__)
|
||||
@functools.cached_property
|
||||
def key(self) -> str: return self.render(ctx="DEBUG")
|
||||
@functools.cached_property
|
||||
@@ -149,6 +152,7 @@ class Variable(Node):
|
||||
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
|
||||
self.expr, self.min, self.max = expr, nmin, nmax
|
||||
def vars(self): return [self]
|
||||
def expand(self) -> List[Node]: return [self] if self.expr is not None else [Variable.num(j) for j in range(self.min, self.max+1)]
|
||||
|
||||
class NumNode(Node):
|
||||
def __init__(self, num:int):
|
||||
@@ -158,6 +162,7 @@ class NumNode(Node):
|
||||
def __index__(self): return self.b
|
||||
def __eq__(self, other): return self.b == other
|
||||
def __hash__(self): return self.hash # needed with __eq__ override
|
||||
def expand(self) -> List[Node]: return [self]
|
||||
|
||||
def create_node(ret:Node):
|
||||
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
|
||||
@@ -190,6 +195,7 @@ class MulNode(OpNode):
|
||||
return Node.__mod__(a, b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
||||
def expand(self) -> List[Node]: return [x*self.b for x in self.a.expand()]
|
||||
|
||||
class DivNode(OpNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
|
||||
@@ -265,6 +271,8 @@ class SumNode(RedNode):
|
||||
else: new_sum.append(x)
|
||||
return Node.__lt__(Node.sum(new_sum), b)
|
||||
return Node.__lt__(self, b)
|
||||
|
||||
def expand(self) -> List[Node]: return [Variable.sum(list(it)) for it in itertools.product(*[x.expand() for x in self.nodes])]
|
||||
|
||||
@property
|
||||
def flat_components(self): # recursively expand sumnode components
|
||||
|
||||
Reference in New Issue
Block a user