support for uop tags (#10477)

* support for uop tags [pr]

* test uop tags
This commit is contained in:
George Hotz
2025-05-22 19:53:48 -07:00
committed by GitHub
parent 8cc2dff4d8
commit 9fc01c1e03
2 changed files with 24 additions and 7 deletions

View File

@@ -2,7 +2,7 @@ from typing import List
import unittest, pytest
from tinygrad import dtypes, Variable
from tinygrad.helpers import DEBUG, Context
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp
from tinygrad.codegen.symbolic import sym
from tinygrad.codegen import full_rewrite, full_rewrite_to_sink
from tinygrad.codegen.expander import expander
@@ -727,6 +727,20 @@ class TestIFUOps(unittest.TestCase):
for st in sink.src:
self.assertEqual(len(st.src), 2)
class TestUOpTags(unittest.TestCase):
def test_inc_by_one(self):
g = UOp.const(dtypes.int, 1) + UOp.const(dtypes.int, 1)
assert g.ssimplify() == 2
pm_plus_1 = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x.replace(arg=x.arg+1, tag=1) if x.tag is None else None)])
pm_strip_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
g = graph_rewrite(g, pm_plus_1)
assert g.ssimplify() == 4
g = graph_rewrite(g, pm_plus_1)
assert g.ssimplify() == 4
g = graph_rewrite(g, pm_strip_tags)
assert g.ssimplify() == 4
g = graph_rewrite(g, pm_plus_1)
assert g.ssimplify() == 6
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -220,8 +220,9 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
class UOpMetaClass(type):
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, metadata:Metadata|None=None, _buffer:Buffer|None=None):
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
metadata:Metadata|None=None, _buffer:Buffer|None=None):
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
for s in src: s.children.add(ref)
if metadata is not None: all_metadata[created] = metadata
@@ -242,21 +243,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
dtype:DType = dtypes.void
src:tuple[UOp, ...] = tuple()
arg:Any = None
tag:Any = None
children:set[weakref.ref[UOp]] = field(default_factory=set)
def __del__(self):
if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg, self.tag))) is not None:
for s in self.src: s.children.discard(ref)
del UOpMetaClass.ucache[k]
def __reduce__(self):
args = [self.op, self.dtype, self.src, self.arg]
args = [self.op, self.dtype, self.src, self.arg, self.tag]
args.append(self.metadata)
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
return UOp, tuple(args)
def replace(self, **kwargs) -> UOp:
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src),
kwargs.pop("arg", self.arg), kwargs.pop("tag", self.tag))
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
if (self.op, self.dtype, self.src, self.arg, self.tag) == new_args: return self
return UOp(*new_args)
@functools.cached_property
def key(self) -> bytes: