mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
try all matches in the function (#7288)
This commit is contained in:
@@ -11,6 +11,18 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_upat_any(self):
|
||||
def test(a, x=None, y=None, z=None):
|
||||
#print(x,y,z)
|
||||
if y is not None: return a+y
|
||||
matcher = PatternMatcher([
|
||||
(UPat.var("a")+UPat.any(UPat.var("x"), UPat.var("y"), UPat.var("z")), test),
|
||||
])
|
||||
v1 = UOp.variable("a", 0, 10)
|
||||
v2 = UOp.variable("b", 0, 10)
|
||||
c1 = v1+v2
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
|
||||
@unittest.skip("closures aren't supported on pattern matchers")
|
||||
def test_match_sz_0(self):
|
||||
match_cnt = 0
|
||||
|
||||
@@ -589,9 +589,10 @@ class UPat(MathTrait):
|
||||
|
||||
class UPatAny(UPat):
|
||||
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
ret = []
|
||||
for x in self.src[0]:
|
||||
if (match:=x.match(uop, store.copy())): return match
|
||||
return []
|
||||
if (match:=x.match(uop, store.copy())): ret.extend(match)
|
||||
return ret
|
||||
|
||||
def deconstruct_function(fxn:Callable) -> Tuple:
|
||||
new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
|
||||
@@ -624,7 +625,8 @@ class PatternMatcher:
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
|
||||
if not early_reject.issubset(ler): continue
|
||||
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
|
||||
for match in p.match(uop, {}):
|
||||
if (ret:=(fxn(ctx, **match) if ctx is not None else fxn(**match))) is not None: return ret
|
||||
return None
|
||||
|
||||
# *** tracking pattern matcher ***
|
||||
|
||||
Reference in New Issue
Block a user