mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix off-by-one error in st_equal (#3131)
* fix off by one error * whitespace
This commit is contained in:
@@ -22,7 +22,7 @@ def st_equal(st1, st2) -> bool:
|
||||
idx = Variable("idx", 0, prod(st1.shape)-1)
|
||||
st1_idx, st1_valid = st1.expr_node(idx)
|
||||
st2_idx, st2_valid = st2.expr_node(idx)
|
||||
for i in range(idx.min, idx.max):
|
||||
for i in range(idx.min, idx.max + 1):
|
||||
st1_off = sym_infer(st1_idx, {idx: i})
|
||||
st2_off = sym_infer(st2_idx, {idx: i})
|
||||
st1_v = sym_infer(st1_valid, {idx: i})
|
||||
@@ -84,6 +84,13 @@ class TestShapeTrackerAdd(unittest.TestCase):
|
||||
st.reshape( (4, 3) )
|
||||
assert st_equal(backup + st.sts[1], st.sts[0])
|
||||
|
||||
def test_off_by_one(self):
|
||||
st1 = ShapeTracker(views=(View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True),
|
||||
View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True)))
|
||||
st2 = ShapeTracker(views=(View(shape=(4,), strides=(1,), offset=0, mask=None, contiguous=True),
|
||||
View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True)))
|
||||
assert not (st_equal(st1, st2))
|
||||
|
||||
class TestShapeTrackerAddVariable(unittest.TestCase):
|
||||
def test_self_add(self):
|
||||
j = Variable("j", 0, 20).bind(10)
|
||||
|
||||
Reference in New Issue
Block a user