added visitor pattern (#1669)

* added visitor pattern

* pylint bug workaround

* added tests, made abstract OpNode inherit from ABC

* fixed assert

* fix check of abstract classes in negative test

* remove assert False
This commit is contained in:
Max Hahn
2023-08-30 18:03:44 +02:00
committed by GitHub
parent fdd7f282cb
commit f9cb31fdc2
3 changed files with 70 additions and 14 deletions

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
import unittest
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_render
from tinygrad.shape.symbolic import Node, MulNode, SumNode, Variable, NumNode, LtNode, sym_render
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):
@@ -366,6 +366,62 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
b = NumNode(0) * a
assert b == 0
assert isinstance(b, NumNode)
def test_num_node_expand(self):
a = NumNode(42)
assert a.expand() == [a]
def test_variable_expand(self):
a = Variable("a", 5, 7)
assert a.expand() == [a]
def test_variable_expand_expr_none(self):
a = Variable(None, 5, 7)
assert a.expand() == [NumNode(5), NumNode(6), NumNode(7)]
def test_mul_node_expand(self):
a = Variable(None, 5, 7)
m = MulNode(a, 3)
assert m.expand() == [NumNode(15), NumNode(18), NumNode(21)]
b = Variable("b", 1, 3)
n = MulNode(b, 3)
assert n.expand() == [Variable("b", 1, 3)*3]
def test_sum_node_expand(self):
a = Variable(None, 1, 3)
b = Variable("b", 5, 7)
s1 = SumNode([a, b])
assert s1.expand() == [Variable.sum([NumNode(i),b]) for i in range(1,4)]
c = Variable(None, 5, 7)
s2 = SumNode([a, c])
assert s2.expand() == [Variable.sum([NumNode(i),NumNode(j)]) for (i,j) in [(1,5), (1,6), (1,7), (2,5), (2,6), (2,7), (3,5), (3,6), (3,7)]]
def test_non_expandable_nodes(self):
expandable_nodes = [Variable, NumNode, MulNode, SumNode]
def test_non_expandable_nodes_recursive(node_cls: Node):
for node_subcls in node_cls.__subclasses__():
# skip expandable classes
if node_subcls in expandable_nodes:
continue
# recurse over subclasses
test_non_expandable_nodes_recursive(node_subcls)
# skip classes with abstract methods
if len(node_subcls.__abstractmethods__) > 0:
continue
# test that node expand is not implemented
node = node_subcls.__new__(node_subcls)
self.assertRaises(NotImplementedError, node.expand)
test_non_expandable_nodes_recursive(Node)
if __name__ == '__main__':
unittest.main()

View File

@@ -99,16 +99,8 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
return zip(new_idxs, new_values)
return zip([[i] for i in range(len(values[0]))], zip(*values))
# TODO: generic visitor pattern?
def expand_node(idx:Node) -> List[Node]:
if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)]
if isinstance(idx, NumNode): return [idx]
if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)]
if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])]
raise NotImplementedError(idx)
def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]:
for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]):
for x in itertools.product(*[idx.expand() for idx in idxs[::-1]]):
yield x[::-1]
class MemOp(NamedTuple):
@@ -144,7 +136,7 @@ class Linearizer(OptimizedKernel):
def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[Token]:
const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc
expanded_nodes = [expand_node(idx) for idx in idxs]
expanded_nodes = [idx.expand() for idx in idxs]
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
upcast_dim = self.get_upcast_dim(i)
@@ -174,7 +166,7 @@ class Linearizer(OptimizedKernel):
return ret
def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> None:
expanded_nodes = [expand_node(idx) for idx in idxs]
expanded_nodes = [idx.expand() for idx in idxs]
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
upcast_dim = self.get_upcast_dim(i)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import abstractmethod
from abc import abstractmethod, ABC
import functools
import itertools
from math import gcd
from tinygrad.helpers import partition
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
@@ -10,7 +11,7 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
class Node:
class Node(ABC):
b: Union[Node, int]
min: int
max: int
@@ -21,6 +22,8 @@ class Node:
if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
return ret
def vars(self): return []
def expand(self) -> List[Node]:
raise NotImplementedError(self.__class__.__name__)
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
@@ -149,6 +152,7 @@ class Variable(Node):
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
def vars(self): return [self]
def expand(self) -> List[Node]: return [self] if self.expr is not None else [Variable.num(j) for j in range(self.min, self.max+1)]
class NumNode(Node):
def __init__(self, num:int):
@@ -158,6 +162,7 @@ class NumNode(Node):
def __index__(self): return self.b
def __eq__(self, other): return self.b == other
def __hash__(self): return self.hash # needed with __eq__ override
def expand(self) -> List[Node]: return [self]
def create_node(ret:Node):
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
@@ -190,6 +195,7 @@ class MulNode(OpNode):
return Node.__mod__(a, b)
def get_bounds(self) -> Tuple[int, int]:
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
def expand(self) -> List[Node]: return [x*self.b for x in self.a.expand()]
class DivNode(OpNode):
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
@@ -265,6 +271,8 @@ class SumNode(RedNode):
else: new_sum.append(x)
return Node.__lt__(Node.sum(new_sum), b)
return Node.__lt__(self, b)
def expand(self) -> List[Node]: return [Variable.sum(list(it)) for it in itertools.product(*[x.expand() for x in self.nodes])]
@property
def flat_components(self): # recursively expand sumnode components