fix tests

This commit is contained in:
George Hotz
2026-05-04 17:33:01 -07:00
parent 51b13466dd
commit ff1258feef
3 changed files with 10 additions and 11 deletions

View File

@@ -428,7 +428,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([w, red])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg==5
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
@@ -438,7 +438,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([w])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD: assert u.src[1].arg==5
if u.op is Ops.LOAD: assert u.src[2].arg==5
def test_where_on_gated_load_with_cast(self):
ridx0 = UOp.range(100, 0)
@@ -451,7 +451,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([w, red])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg == 5
def test_where_on_casted_gated_load_extra_cond(self):
ridx0 = UOp.range(100, 0)

View File

@@ -287,12 +287,12 @@ pm_render = PatternMatcher([
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
# Where after gated load becomes alt value
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c", dtype=dtypes.bool).logical_not()).or_casted(),),
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat()).or_casted(), UPat.var("c"), UPat()), allow_any_len=True,
name="l").or_casted(), UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], l.src[1],
a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+l.src[3:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat()).or_casted(),
UPat.var("c", dtype=dtypes.bool).logical_not(), UPat()), allow_any_len=True, name="l").or_casted()), lambda c,l,a:
l.replace(src=(l.src[0], l.src[1], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+l.src[3:]).cast(a.dtype)),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

View File

@@ -31,8 +31,7 @@ class Estimates:
if u.op in {Ops.LOAD, Ops.STORE}:
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
# TODO: is this correct? this all needs to be cleaned up
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
if len(u.src) > 2: dont_count = dont_count.union(u.src[1 if u.op is Ops.LOAD else 2].toposort())
elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].toposort())
for u in uops: