From 004af512e60e9155f31343bef1ac2dff15328da0 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:17:04 +0700 Subject: [PATCH] try all matches in the function (#7288) --- test/unit/test_pattern_matcher.py | 12 ++++++++++++ tinygrad/ops.py | 8 +++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index 2d29e6a720..22c9c6f630 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -11,6 +11,18 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), None) + def test_upat_any(self): + def test(a, x=None, y=None, z=None): + #print(x,y,z) + if y is not None: return a+y + matcher = PatternMatcher([ + (UPat.var("a")+UPat.any(UPat.var("x"), UPat.var("y"), UPat.var("z")), test), + ]) + v1 = UOp.variable("a", 0, 10) + v2 = UOp.variable("b", 0, 10) + c1 = v1+v2 + self.assertEqual(matcher.rewrite(c1), c1) + @unittest.skip("closures aren't supported on pattern matchers") def test_match_sz_0(self): match_cnt = 0 diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e73b17daf3..61476c98d4 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -589,9 +589,10 @@ class UPat(MathTrait): class UPatAny(UPat): def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: + ret = [] for x in self.src[0]: - if (match:=x.match(uop, store.copy())): return match - return [] + if (match:=x.match(uop, store.copy())): ret.extend(match) + return ret def deconstruct_function(fxn:Callable) -> Tuple: new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names} @@ -624,7 +625,8 @@ class PatternMatcher: ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))]) for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])): if not early_reject.issubset(ler): continue - if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret + for match in p.match(uop, {}): + if (ret:=(fxn(ctx, **match) if ctx is not None else fxn(**match))) is not None: return ret return None # *** tracking pattern matcher ***