From ff1258feef92a15c0ada183bb00481db83f63db3 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 4 May 2026 17:33:01 -0700 Subject: [PATCH] fix tests --- test/null/test_uop_graph.py | 6 +++--- tinygrad/codegen/late/devectorizer.py | 12 ++++++------ tinygrad/renderer/__init__.py | 3 +-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py index ea84737d98..bb7414f35a 100644 --- a/test/null/test_uop_graph.py +++ b/test/null/test_uop_graph.py @@ -428,7 +428,7 @@ class TestUOpGraph(unittest.TestCase): uops = to_uops_list([w, red]) for u in uops: assert u.op is not Ops.WHERE - if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5 + if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg==5 def test_where_on_gated_load_folds_swapped_branches(self): ridx0 = UOp.range(100, 0) @@ -438,7 +438,7 @@ class TestUOpGraph(unittest.TestCase): uops = to_uops_list([w]) for u in uops: assert u.op is not Ops.WHERE - if u.op is Ops.LOAD: assert u.src[1].arg==5 + if u.op is Ops.LOAD: assert u.src[2].arg==5 def test_where_on_gated_load_with_cast(self): ridx0 = UOp.range(100, 0) @@ -451,7 +451,7 @@ class TestUOpGraph(unittest.TestCase): uops = to_uops_list([w, red]) for u in uops: assert u.op is not Ops.WHERE - if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5 + if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg == 5 def test_where_on_casted_gated_load_extra_cond(self): ridx0 = UOp.range(100, 0) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index dd15089e13..58759e8e9a 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -287,12 +287,12 @@ pm_render = PatternMatcher([ if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None), # Where after gated load becomes alt value # NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint) - (UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(), - UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+ - l.src[2:]).cast(a.dtype)), - (UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c", dtype=dtypes.bool).logical_not()).or_casted(),), - allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype - else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)), + (UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat()).or_casted(), UPat.var("c"), UPat()), allow_any_len=True, + name="l").or_casted(), UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], l.src[1], + a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+l.src[3:]).cast(a.dtype)), + (UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat()).or_casted(), + UPat.var("c", dtype=dtypes.bool).logical_not(), UPat()), allow_any_len=True, name="l").or_casted()), lambda c,l,a: + l.replace(src=(l.src[0], l.src[1], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+l.src[3:]).cast(a.dtype)), ]) # *** Ops.REDUCE -> Ops.DEFINE_ACC *** diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 044d6b28f2..b0900bafdd 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -31,8 +31,7 @@ class Estimates: if u.op in {Ops.LOAD, Ops.STORE}: # if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate)) - # TODO: is this correct? this all needs to be cleaned up - if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort()) + if len(u.src) > 2: dont_count = dont_count.union(u.src[1 if u.op is Ops.LOAD else 2].toposort()) elif u.op is Ops.IF: dont_count = dont_count.union(u.src[0].toposort()) for u in uops: