diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 38418936b6..0ad6f072a2 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -2,7 +2,7 @@ from typing import List import unittest, pytest from tinygrad import dtypes, Variable from tinygrad.helpers import DEBUG, Context -from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite +from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp from tinygrad.codegen.symbolic import sym from tinygrad.codegen import full_rewrite, full_rewrite_to_sink from tinygrad.codegen.expander import expander @@ -727,6 +727,20 @@ class TestIFUOps(unittest.TestCase): for st in sink.src: self.assertEqual(len(st.src), 2) +class TestUOpTags(unittest.TestCase): + def test_inc_by_one(self): + g = UOp.const(dtypes.int, 1) + UOp.const(dtypes.int, 1) + assert g.ssimplify() == 2 + pm_plus_1 = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x.replace(arg=x.arg+1, tag=1) if x.tag is None else None)]) + pm_strip_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) + g = graph_rewrite(g, pm_plus_1) + assert g.ssimplify() == 4 + g = graph_rewrite(g, pm_plus_1) + assert g.ssimplify() == 4 + g = graph_rewrite(g, pm_strip_tags) + assert g.ssimplify() == 4 + g = graph_rewrite(g, pm_plus_1) + assert g.ssimplify() == 6 if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 2a148fe191..d40aa99303 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -220,8 +220,9 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s class UOpMetaClass(type): ucache:dict[tuple, weakref.ReferenceType[UOp]] = {} - def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, metadata:Metadata|None=None, _buffer:Buffer|None=None): - if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret + def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None, + metadata:Metadata|None=None, _buffer:Buffer|None=None): + if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key)) for s in src: s.children.add(ref) if metadata is not None: all_metadata[created] = metadata @@ -242,21 +243,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass): dtype:DType = dtypes.void src:tuple[UOp, ...] = tuple() arg:Any = None + tag:Any = None children:set[weakref.ref[UOp]] = field(default_factory=set) def __del__(self): if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1) - if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None: + if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg, self.tag))) is not None: for s in self.src: s.children.discard(ref) del UOpMetaClass.ucache[k] def __reduce__(self): - args = [self.op, self.dtype, self.src, self.arg] + args = [self.op, self.dtype, self.src, self.arg, self.tag] args.append(self.metadata) if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized) return UOp, tuple(args) def replace(self, **kwargs) -> UOp: - new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg)) + new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), + kwargs.pop("arg", self.arg), kwargs.pop("tag", self.tag)) assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}" - if (self.op, self.dtype, self.src, self.arg) == new_args: return self + if (self.op, self.dtype, self.src, self.arg, self.tag) == new_args: return self return UOp(*new_args) @functools.cached_property def key(self) -> bytes: