mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
handle stride 0 variable reshape (#10536)
This commit is contained in:
@@ -872,5 +872,46 @@ class TestRender(unittest.TestCase):
|
||||
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
|
||||
self.assertEqual(valid.render(), "(ridx0<2)")
|
||||
|
||||
class TestVariableReshape(unittest.TestCase):
|
||||
def test_reshape(self):
|
||||
st = ShapeTracker.from_shape((3,))
|
||||
st = st.reshape((Variable("i", 1, 10),))
|
||||
assert len(st.views) == 1
|
||||
|
||||
def test_reshape_stride_0(self):
|
||||
st = ShapeTracker.from_shape((3,), (0,))
|
||||
st = st.reshape((Variable("i", 1, 10).bind(3),))
|
||||
assert len(st.views) == 1, f"multiview {st}"
|
||||
|
||||
def test_reshape_bound(self):
|
||||
st = ShapeTracker.from_shape((3,))
|
||||
st = st.reshape((Variable("i", 1, 10).bind(3),))
|
||||
assert len(st.views) == 1
|
||||
|
||||
def test_add(self):
|
||||
st1 = ShapeTracker.from_shape((3,))
|
||||
st2 = ShapeTracker.from_shape((Variable("i", 1, 10),))
|
||||
st = st1+st2
|
||||
assert len(st.views) == 1
|
||||
|
||||
def test_add_stride_0(self):
|
||||
st1 = ShapeTracker.from_shape((3,), (0,))
|
||||
st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),), (0,))
|
||||
st = st1+st2
|
||||
assert len(st.views) == 1, f"multiview {st}"
|
||||
|
||||
def test_add_bound(self):
|
||||
st1 = ShapeTracker.from_shape((3,))
|
||||
st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
|
||||
st = st1+st2
|
||||
assert len(st.views) == 1
|
||||
|
||||
def test_simplify(self):
|
||||
st1 = ShapeTracker.from_shape((3,))
|
||||
st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
|
||||
st = ShapeTracker((st1.views[0], st2.views[0]))
|
||||
st = st.simplify()
|
||||
assert len(st.views) == 1
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -160,7 +160,11 @@ class View:
|
||||
if vm1.mask:
|
||||
if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
|
||||
return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
|
||||
if not all_int(vm1.shape): return None
|
||||
if not all_int(vm1.shape):
|
||||
# if all strides are 0 and vm2 is unmasked, return vm1
|
||||
if all(x == 0 for x in vm2.strides+vm1.strides) and vm2.mask is None: return vm1
|
||||
# TODO: handle more cases
|
||||
return None
|
||||
|
||||
# Project vm1's offset and strides on to vm2.
|
||||
origin = unravel(vm2.shape, vm1.offset)
|
||||
|
||||
@@ -73,7 +73,8 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
|
||||
|
||||
# Tensor const has a device and an unmasked ShapeTracker of stride 0
|
||||
# NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)), lambda st: all(v.mask is None for v in st.st.views)),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
|
||||
lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
|
||||
|
||||
# DETACH and CONTIGUOUS change how we interpret the source UOp
|
||||
# CONTIGUOUS ensures the source UOp realizes
|
||||
|
||||
Reference in New Issue
Block a user