diff --git a/test/unit/test_shapetracker_math.py b/test/unit/test_shapetracker_math.py index d25807c1e1..893764aee6 100644 --- a/test/unit/test_shapetracker_math.py +++ b/test/unit/test_shapetracker_math.py @@ -22,7 +22,7 @@ def st_equal(st1, st2) -> bool: idx = Variable("idx", 0, prod(st1.shape)-1) st1_idx, st1_valid = st1.expr_node(idx) st2_idx, st2_valid = st2.expr_node(idx) - for i in range(idx.min, idx.max): + for i in range(idx.min, idx.max + 1): st1_off = sym_infer(st1_idx, {idx: i}) st2_off = sym_infer(st2_idx, {idx: i}) st1_v = sym_infer(st1_valid, {idx: i}) @@ -84,6 +84,13 @@ class TestShapeTrackerAdd(unittest.TestCase): st.reshape( (4, 3) ) assert st_equal(backup + st.sts[1], st.sts[0]) + def test_off_by_one(self): + st1 = ShapeTracker(views=(View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True), + View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True))) + st2 = ShapeTracker(views=(View(shape=(4,), strides=(1,), offset=0, mask=None, contiguous=True), + View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True))) + assert not (st_equal(st1, st2)) + class TestShapeTrackerAddVariable(unittest.TestCase): def test_self_add(self): j = Variable("j", 0, 20).bind(10)