diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py
index 223acceb04..c19672bf0d 100644
--- a/test/unit/test_shapetracker.py
+++ b/test/unit/test_shapetracker.py
@@ -872,5 +872,54 @@ class TestRender(unittest.TestCase):
     self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
     self.assertEqual(valid.render(), "(ridx0<2)")
 
+class TestVariableReshape(unittest.TestCase):
+  # every case below expects the resulting ShapeTracker to end up with exactly one view
+  def test_reshape(self):
+    # new shape is an unbound Variable
+    st = ShapeTracker.from_shape((3,))
+    st = st.reshape((Variable("i", 1, 10),))
+    self.assertEqual(len(st.views), 1, f"multiview {st}")
+
+  def test_reshape_stride_0(self):
+    # same reshape, starting from a tracker built with (0,) as the second argument (stride 0, per the test name)
+    st = ShapeTracker.from_shape((3,), (0,))
+    st = st.reshape((Variable("i", 1, 10).bind(3),))
+    self.assertEqual(len(st.views), 1, f"multiview {st}")
+
+  def test_reshape_bound(self):
+    # new shape is a Variable bound to 3, matching the original size
+    st = ShapeTracker.from_shape((3,))
+    st = st.reshape((Variable("i", 1, 10).bind(3),))
+    self.assertEqual(len(st.views), 1, f"multiview {st}")
+
+  def test_add(self):
+    # add an int-shaped tracker and one whose shape is an unbound Variable
+    st1 = ShapeTracker.from_shape((3,))
+    st2 = ShapeTracker.from_shape((Variable("i", 1, 10),))
+    st = st1+st2
+    self.assertEqual(len(st.views), 1, f"multiview {st}")
+
+  def test_add_stride_0(self):
+    # both operands are built with (0,) as the second argument
+    st1 = ShapeTracker.from_shape((3,), (0,))
+    st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),), (0,))
+    st = st1+st2
+    self.assertEqual(len(st.views), 1, f"multiview {st}")
+
+  def test_add_bound(self):
+    # the symbolic operand is bound to 3
+    st1 = ShapeTracker.from_shape((3,))
+    st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
+    st = st1+st2
+    self.assertEqual(len(st.views), 1, f"multiview {st}")
+
+  def test_simplify(self):
+    # stack the two views by hand, then rely on simplify() to collapse them
+    st1 = ShapeTracker.from_shape((3,))
+    st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),))
+    st = ShapeTracker((st1.views[0], st2.views[0]))
+    st = st.simplify()
+    self.assertEqual(len(st.views), 1, f"multiview {st}")
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index bc7ba0716c..2fb971b4cc 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -160,7 +160,12 @@ class View:
     if vm1.mask:
       if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
       return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
-    if not all_int(vm1.shape): return None
+    if not all_int(vm1.shape):
+      # if all strides are 0 and vm2 is unmasked, return vm1
+      # NOTE(review): vm1 is returned as-is, so vm2.offset is not carried over -- confirm the two offsets cannot differ here
+      if all(x == 0 for x in vm2.strides+vm1.strides) and vm2.mask is None: return vm1
+      # TODO: handle more cases
+      return None
 
     # Project vm1's offset and strides on to vm2.
     origin = unravel(vm2.shape, vm1.offset)
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index 3dbdc10b16..dd55b19076 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -73,7 +73,8 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
 
   # Tensor const has a device and an unmasked ShapeTracker of stride 0
   # NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
-  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)), lambda st: all(v.mask is None for v in st.st.views)),
+  (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
+   lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
 
   # DETACH and CONTIGUOUS change how we interpret the source UOp
   # CONTIGUOUS ensures the source UOp realizes