mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
fix tests
This commit is contained in:
@@ -428,7 +428,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
uops = to_uops_list([w, red])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg==5
|
||||
|
||||
def test_where_on_gated_load_folds_swapped_branches(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
@@ -438,7 +438,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
uops = to_uops_list([w])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD: assert u.src[1].arg==5
|
||||
if u.op is Ops.LOAD: assert u.src[2].arg==5
|
||||
|
||||
def test_where_on_gated_load_with_cast(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
@@ -451,7 +451,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
uops = to_uops_list([w, red])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg == 5
|
||||
|
||||
def test_where_on_casted_gated_load_extra_cond(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
|
||||
@@ -287,12 +287,12 @@ pm_render = PatternMatcher([
|
||||
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
|
||||
# Where after gated load becomes alt value
|
||||
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
|
||||
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
|
||||
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
|
||||
l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c", dtype=dtypes.bool).logical_not()).or_casted(),),
|
||||
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
|
||||
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat()).or_casted(), UPat.var("c"), UPat()), allow_any_len=True,
|
||||
name="l").or_casted(), UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], l.src[1],
|
||||
a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+l.src[3:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat()).or_casted(),
|
||||
UPat.var("c", dtype=dtypes.bool).logical_not(), UPat()), allow_any_len=True, name="l").or_casted()), lambda c,l,a:
|
||||
l.replace(src=(l.src[0], l.src[1], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+l.src[3:]).cast(a.dtype)),
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
||||
@@ -31,8 +31,7 @@ class Estimates:
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
|
||||
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
|
||||
# TODO: is this correct? this all needs to be cleaned up
|
||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
|
||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[1 if u.op is Ops.LOAD else 2].toposort())
|
||||
elif u.op is Ops.IF:
|
||||
dont_count = dont_count.union(u.src[0].toposort())
|
||||
for u in uops:
|
||||
|
||||
Reference in New Issue
Block a user