mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
support for uop tags (#10477)
* support for uop tags [pr] * test uop tags
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import List
|
||||
import unittest, pytest
|
||||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.helpers import DEBUG, Context
|
||||
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite
|
||||
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp
|
||||
from tinygrad.codegen.symbolic import sym
|
||||
from tinygrad.codegen import full_rewrite, full_rewrite_to_sink
|
||||
from tinygrad.codegen.expander import expander
|
||||
@@ -727,6 +727,20 @@ class TestIFUOps(unittest.TestCase):
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
class TestUOpTags(unittest.TestCase):
|
||||
def test_inc_by_one(self):
|
||||
g = UOp.const(dtypes.int, 1) + UOp.const(dtypes.int, 1)
|
||||
assert g.ssimplify() == 2
|
||||
pm_plus_1 = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x.replace(arg=x.arg+1, tag=1) if x.tag is None else None)])
|
||||
pm_strip_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
g = graph_rewrite(g, pm_plus_1)
|
||||
assert g.ssimplify() == 4
|
||||
g = graph_rewrite(g, pm_plus_1)
|
||||
assert g.ssimplify() == 4
|
||||
g = graph_rewrite(g, pm_strip_tags)
|
||||
assert g.ssimplify() == 4
|
||||
g = graph_rewrite(g, pm_plus_1)
|
||||
assert g.ssimplify() == 6
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -220,8 +220,9 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
||||
|
||||
class UOpMetaClass(type):
|
||||
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, metadata:Metadata|None=None, _buffer:Buffer|None=None):
|
||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
|
||||
metadata:Metadata|None=None, _buffer:Buffer|None=None):
|
||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
|
||||
for s in src: s.children.add(ref)
|
||||
if metadata is not None: all_metadata[created] = metadata
|
||||
@@ -242,21 +243,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
dtype:DType = dtypes.void
|
||||
src:tuple[UOp, ...] = tuple()
|
||||
arg:Any = None
|
||||
tag:Any = None
|
||||
children:set[weakref.ref[UOp]] = field(default_factory=set)
|
||||
def __del__(self):
|
||||
if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
|
||||
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
|
||||
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg, self.tag))) is not None:
|
||||
for s in self.src: s.children.discard(ref)
|
||||
del UOpMetaClass.ucache[k]
|
||||
def __reduce__(self):
|
||||
args = [self.op, self.dtype, self.src, self.arg]
|
||||
args = [self.op, self.dtype, self.src, self.arg, self.tag]
|
||||
args.append(self.metadata)
|
||||
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
|
||||
return UOp, tuple(args)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
|
||||
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src),
|
||||
kwargs.pop("arg", self.arg), kwargs.pop("tag", self.tag))
|
||||
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
|
||||
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
|
||||
if (self.op, self.dtype, self.src, self.arg, self.tag) == new_args: return self
|
||||
return UOp(*new_args)
|
||||
@functools.cached_property
|
||||
def key(self) -> bytes:
|
||||
|
||||
Reference in New Issue
Block a user