From bfcaa2f70e5864b89c48cce5776f6a3bb553e66e Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 28 Mar 2024 12:16:38 -0400 Subject: [PATCH] assert `__setitem__` if used other than disk (#3972) * assert `__setitem__` if used other than disk * that is not implemented --- test/imported/test_indexing.py | 3 ++- tinygrad/tensor.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py index e2097bd234..c4968b67da 100644 --- a/test/imported/test_indexing.py +++ b/test/imported/test_indexing.py @@ -1525,7 +1525,8 @@ class TestNumpy(unittest.TestCase): def test_broaderrors_indexing(self): a = Tensor.zeros(5, 5) self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2])) - self.assertRaises(IndexError, a.__setitem__, ([0, 1], [0, 1, 2]), 0) + # TODO setitem + # self.assertRaises(IndexError, a.__setitem__, ([0, 1], [0, 1, 2]), 0) # TODO setitem ''' diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 91be40d883..e6bc975345 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -505,7 +505,9 @@ class Tensor: ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:]) return ret - def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v) + def __setitem__(self,indices,v): + if isinstance(self.device, str) and self.device.startswith("DISK"): return self.__getitem__(indices).assign(v) + raise NotImplementedError("not implemented yet") # NOTE: using slice is discouraged and things should migrate to pad and shrink def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor: