From f9cb31fdc2c8df2fdb84873924b6dbd8e67cd315 Mon Sep 17 00:00:00 2001 From: Max Hahn <40292613+Addiictet@users.noreply.github.com> Date: Wed, 30 Aug 2023 18:03:44 +0200 Subject: [PATCH] added visitor pattern (#1669) * added visitor pattern * pylint bug workaround * added tests, made abstract OpNode inherit from ABC * fixed assert * fix check of abstract classes in negative test * remove assert False --- test/unit/test_symbolic.py | 58 +++++++++++++++++++++++++++++++++- tinygrad/codegen/linearizer.py | 14 ++------ tinygrad/shape/symbolic.py | 12 +++++-- 3 files changed, 70 insertions(+), 14 deletions(-) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index c4c73865ca..b387a50113 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import unittest -from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_render +from tinygrad.shape.symbolic import Node, MulNode, SumNode, Variable, NumNode, LtNode, sym_render class TestSymbolic(unittest.TestCase): def helper_test_variable(self, v, n, m, s): @@ -366,6 +366,62 @@ class TestSymbolicSymbolicOps(unittest.TestCase): b = NumNode(0) * a assert b == 0 assert isinstance(b, NumNode) + + def test_num_node_expand(self): + a = NumNode(42) + assert a.expand() == [a] + + def test_variable_expand(self): + a = Variable("a", 5, 7) + assert a.expand() == [a] + + def test_variable_expand_expr_none(self): + a = Variable(None, 5, 7) + assert a.expand() == [NumNode(5), NumNode(6), NumNode(7)] + + def test_mul_node_expand(self): + a = Variable(None, 5, 7) + m = MulNode(a, 3) + assert m.expand() == [NumNode(15), NumNode(18), NumNode(21)] + + b = Variable("b", 1, 3) + n = MulNode(b, 3) + assert n.expand() == [Variable("b", 1, 3)*3] + + def test_sum_node_expand(self): + a = Variable(None, 1, 3) + b = Variable("b", 5, 7) + + s1 = SumNode([a, b]) + assert s1.expand() == [Variable.sum([NumNode(i),b]) for i in range(1,4)] + + c = Variable(None, 5, 7) + + s2 = SumNode([a, c]) + assert s2.expand() == [Variable.sum([NumNode(i),NumNode(j)]) for (i,j) in [(1,5), (1,6), (1,7), (2,5), (2,6), (2,7), (3,5), (3,6), (3,7)]] + + def test_non_expandable_nodes(self): + expandable_nodes = [Variable, NumNode, MulNode, SumNode] + + def test_non_expandable_nodes_recursive(node_cls: Node): + for node_subcls in node_cls.__subclasses__(): + # skip expandable classes + if node_subcls in expandable_nodes: + continue + + # recurse over subclasses + test_non_expandable_nodes_recursive(node_subcls) + + # skip classes with abstract methods + if len(node_subcls.__abstractmethods__) > 0: + continue + + # test that node expand is not implemented + node = node_subcls.__new__(node_subcls) + self.assertRaises(NotImplementedError, node.expand) + + test_non_expandable_nodes_recursive(Node) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 399201f78e..b791a0238b 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -99,16 +99,8 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True): return zip(new_idxs, new_values) return zip([[i] for i in range(len(values[0]))], zip(*values)) -# TODO: generic visitor pattern? -def expand_node(idx:Node) -> List[Node]: - if isinstance(idx, Variable): return [idx] if idx.expr is not None else [Variable.num(j) for j in range(idx.min, idx.max+1)] - if isinstance(idx, NumNode): return [idx] - if isinstance(idx, MulNode): return [x*idx.b for x in expand_node(idx.a)] - if isinstance(idx, SumNode): return [Variable.sum(list(it)) for it in itertools.product(*[expand_node(x) for x in idx.nodes])] - raise NotImplementedError(idx) - def expand_idxs(idxs:Sequence[Node]) -> Iterator[Tuple[Node, ...]]: - for x in itertools.product(*[expand_node(idx) for idx in idxs[::-1]]): + for x in itertools.product(*[idx.expand() for idx in idxs[::-1]]): yield x[::-1] class MemOp(NamedTuple): @@ -144,7 +136,7 @@ class Linearizer(OptimizedKernel): def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[Token]: const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc - expanded_nodes = [expand_node(idx) for idx in idxs] + expanded_nodes = [idx.expand() for idx in idxs] _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])] upcast_dim = self.get_upcast_dim(i) @@ -174,7 +166,7 @@ class Linearizer(OptimizedKernel): return ret def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> None: - expanded_nodes = [expand_node(idx) for idx in idxs] + expanded_nodes = [idx.expand() for idx in idxs] _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])] upcast_dim = self.get_upcast_dim(i) diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 22a49f8925..68f5886e98 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -1,6 +1,7 @@ from __future__ import annotations -from abc import abstractmethod +from abc import abstractmethod, ABC import functools +import itertools from math import gcd from tinygrad.helpers import partition from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any @@ -10,7 +11,7 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node)) -class Node: +class Node(ABC): b: Union[Node, int] min: int max: int @@ -21,6 +22,8 @@ class Node: if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1] return ret def vars(self): return [] + def expand(self) -> List[Node]: + raise NotImplementedError(self.__class__.__name__) @functools.cached_property def key(self) -> str: return self.render(ctx="DEBUG") @functools.cached_property @@ -149,6 +152,7 @@ class Variable(Node): def __init__(self, expr:Optional[str], nmin:int, nmax:int): self.expr, self.min, self.max = expr, nmin, nmax def vars(self): return [self] + def expand(self) -> List[Node]: return [self] if self.expr is not None else [Variable.num(j) for j in range(self.min, self.max+1)] class NumNode(Node): def __init__(self, num:int): @@ -158,6 +162,7 @@ class NumNode(Node): def __index__(self): return self.b def __eq__(self, other): return self.b == other def __hash__(self): return self.hash # needed with __eq__ override + def expand(self) -> List[Node]: return [self] def create_node(ret:Node): assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}" @@ -190,6 +195,7 @@ class MulNode(OpNode): return Node.__mod__(a, b) def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b) + def expand(self) -> List[Node]: return [x*self.b for x in self.a.expand()] class DivNode(OpNode): def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div @@ -265,6 +271,8 @@ class SumNode(RedNode): else: new_sum.append(x) return Node.__lt__(Node.sum(new_sum), b) return Node.__lt__(self, b) + + def expand(self) -> List[Node]: return [Variable.sum(list(it)) for it in itertools.product(*[x.expand() for x in self.nodes])] @property def flat_components(self): # recursively expand sumnode components