mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
feat: assert on bufferview math (#12772)
This commit is contained in:
@@ -433,17 +433,15 @@ class TestDiskTensorMovement(unittest.TestCase):
|
||||
t = Tensor(self.fn)
|
||||
self.assertListEqual(t[16:18].tolist(), [16,17])
|
||||
|
||||
# TODO: fix this! at least assert on it
|
||||
@unittest.expectedFailure
|
||||
def test_slice_read_cat(self):
|
||||
t = Tensor(self.fn)
|
||||
self.assertListEqual(Tensor.cat(t[16:18], t[20:22]).tolist(), [16,17,20,21])
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(Tensor.cat(t[16:18], t[20:22]).tolist(), [16,17,20,21])
|
||||
|
||||
# TODO: fix this! at least assert on it
|
||||
@unittest.expectedFailure
|
||||
def test_slice_sum(self):
|
||||
t = Tensor(self.fn)
|
||||
self.assertListEqual((t[16:18]+t[20:22]).tolist(), [16+20,17+21])
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual((t[16:18]+t[20:22]).tolist(), [16+20,17+21])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -218,7 +218,9 @@ def late_buffer_view(t:UOp, b:UOp):
|
||||
|
||||
# walk up for the INDEX
|
||||
x = t
|
||||
while not any(u.op is Ops.INDEX for u in x.src): x = x.src[0]
|
||||
while not any(u.op is Ops.INDEX for u in x.src):
|
||||
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
||||
x = x.src[0]
|
||||
x = next(u for u in x.src if u.op is Ops.INDEX)
|
||||
|
||||
if len(shape) == 0: offset = x.src[1].arg
|
||||
|
||||
Reference in New Issue
Block a user