try all matches in the function (#7288)

This commit is contained in:
George Hotz
2024-10-25 13:17:04 +07:00
committed by GitHub
parent bcf0537653
commit 004af512e6
2 changed files with 17 additions and 3 deletions

View File

@@ -11,6 +11,18 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_upat_any(self):
def test(a, x=None, y=None, z=None):
#print(x,y,z)
if y is not None: return a+y
matcher = PatternMatcher([
(UPat.var("a")+UPat.any(UPat.var("x"), UPat.var("y"), UPat.var("z")), test),
])
v1 = UOp.variable("a", 0, 10)
v2 = UOp.variable("b", 0, 10)
c1 = v1+v2
self.assertEqual(matcher.rewrite(c1), c1)
@unittest.skip("closures aren't supported on pattern matchers")
def test_match_sz_0(self):
match_cnt = 0

View File

@@ -589,9 +589,10 @@ class UPat(MathTrait):
class UPatAny(UPat):
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
ret = []
for x in self.src[0]:
if (match:=x.match(uop, store.copy())): return match
return []
if (match:=x.match(uop, store.copy())): ret.extend(match)
return ret
def deconstruct_function(fxn:Callable) -> Tuple:
new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
@@ -624,7 +625,8 @@ class PatternMatcher:
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
if not early_reject.issubset(ler): continue
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
for match in p.match(uop, {}):
if (ret:=(fxn(ctx, **match) if ctx is not None else fxn(**match))) is not None: return ret
return None
# *** tracking pattern matcher ***