fix off-by-one error in st_equal (#3131)

* fix off by one error

* whitespace
This commit is contained in:
Paul Gustafson
2024-01-15 11:32:13 -08:00
committed by GitHub
parent 44c05919c1
commit 6bb65cd02e

View File

@@ -22,7 +22,7 @@ def st_equal(st1, st2) -> bool:
idx = Variable("idx", 0, prod(st1.shape)-1)
st1_idx, st1_valid = st1.expr_node(idx)
st2_idx, st2_valid = st2.expr_node(idx)
for i in range(idx.min, idx.max):
for i in range(idx.min, idx.max + 1):
st1_off = sym_infer(st1_idx, {idx: i})
st2_off = sym_infer(st2_idx, {idx: i})
st1_v = sym_infer(st1_valid, {idx: i})
@@ -84,6 +84,13 @@ class TestShapeTrackerAdd(unittest.TestCase):
st.reshape( (4, 3) )
assert st_equal(backup + st.sts[1], st.sts[0])
def test_off_by_one(self):
st1 = ShapeTracker(views=(View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True),
View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True)))
st2 = ShapeTracker(views=(View(shape=(4,), strides=(1,), offset=0, mask=None, contiguous=True),
View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True)))
assert not (st_equal(st1, st2))
class TestShapeTrackerAddVariable(unittest.TestCase):
def test_self_add(self):
j = Variable("j", 0, 20).bind(10)