handle stride 0 variable reshape (#10536)

This commit is contained in:
George Hotz
2025-05-27 10:00:24 -07:00
committed by GitHub
parent 0515622d95
commit a07caaca0d
3 changed files with 48 additions and 2 deletions

View File

@@ -872,5 +872,46 @@ class TestRender(unittest.TestCase):
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
self.assertEqual(valid.render(), "(ridx0<2)")
class TestVariableReshape(unittest.TestCase):
def test_reshape(self):
st = ShapeTracker.from_shape((3,))
st = st.reshape((Variable("i", 1, 10),))
assert len(st.views) == 1
def test_reshape_stride_0(self):
st = ShapeTracker.from_shape((3,), (0,))
st = st.reshape((Variable("i", 1, 10).bind(3),))
assert len(st.views) == 1, f"multiview {st}"
def test_reshape_bound(self):
st = ShapeTracker.from_shape((3,))
st = st.reshape((Variable("i", 1, 10).bind(3),))
assert len(st.views) == 1
def test_add(self):
st1 = ShapeTracker.from_shape((3,))
st2 = ShapeTracker.from_shape((Variable("i", 1, 10),))
st = st1+st2
assert len(st.views) == 1
def test_add_stride_0(self):
st1 = ShapeTracker.from_shape((3,), (0,))
st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),), (0,))
st = st1+st2
assert len(st.views) == 1, f"multiview {st}"
def test_add_bound(self):
st1 = ShapeTracker.from_shape((3,))
st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
st = st1+st2
assert len(st.views) == 1
def test_simplify(self):
st1 = ShapeTracker.from_shape((3,))
st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
st = ShapeTracker((st1.views[0], st2.views[0]))
st = st.simplify()
assert len(st.views) == 1
if __name__ == '__main__':
unittest.main()

View File

@@ -160,7 +160,11 @@ class View:
if vm1.mask:
if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
if not all_int(vm1.shape): return None
if not all_int(vm1.shape):
# if all strides are 0 and vm2 is unmasked, return vm1
if all(x == 0 for x in vm2.strides+vm1.strides) and vm2.mask is None: return vm1
# TODO: handle more cases
return None
# Project vm1's offset and strides on to vm2.
origin = unravel(vm2.shape, vm1.offset)

View File

@@ -73,7 +73,8 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
# Tensor const has a device and an unmasked ShapeTracker of stride 0
# NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)), lambda st: all(v.mask is None for v in st.st.views)),
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes