diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 5f210175f8..9b0aed2b0b 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -433,17 +433,15 @@ class TestDiskTensorMovement(unittest.TestCase): t = Tensor(self.fn) self.assertListEqual(t[16:18].tolist(), [16,17]) - # TODO: fix this! at least assert on it - @unittest.expectedFailure def test_slice_read_cat(self): t = Tensor(self.fn) - self.assertListEqual(Tensor.cat(t[16:18], t[20:22]).tolist(), [16,17,20,21]) + with self.assertRaises(AssertionError): + self.assertListEqual(Tensor.cat(t[16:18], t[20:22]).tolist(), [16,17,20,21]) - # TODO: fix this! at least assert on it - @unittest.expectedFailure def test_slice_sum(self): t = Tensor(self.fn) - self.assertListEqual((t[16:18]+t[20:22]).tolist(), [16+20,17+21]) + with self.assertRaises(AssertionError): + self.assertListEqual((t[16:18]+t[20:22]).tolist(), [16+20,17+21]) if __name__ == "__main__": unittest.main() diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 2ce45af401..581862570c 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -218,7 +218,9 @@ def late_buffer_view(t:UOp, b:UOp): # walk up for the INDEX x = t - while not any(u.op is Ops.INDEX for u in x.src): x = x.src[0] + while not any(u.op is Ops.INDEX for u in x.src): + assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise" + x = x.src[0] x = next(u for u in x.src if u.op is Ops.INDEX) if len(shape) == 0: offset = x.src[1].arg