diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index f6b596056d..d35cd3f320 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -165,7 +165,8 @@ def isin_tensor_tensor_out(x, y, *, assume_unique=False, invert=False, out=None) @torch.library.impl("aten::randperm.generator_out", "privateuseone") def randperm_generator(n, generator=None, out=None): - return out.copy_(wrap(Tensor.randperm(n, generator=generator, device=unwrap(out).device))) + if generator is not None: raise NotImplementedError("tinygrad torch backend does not support torch.Generator for randperm") + return out.copy_(wrap(Tensor.randperm(n, device=unwrap(out).device))) @torch.library.impl("aten::_linalg_eigh", "privateuseone") # TODO: move to tinygrad diff --git a/test/unit/test_indexing.py b/test/unit/test_indexing.py index 8e195f0f1d..05066fa71b 100644 --- a/test/unit/test_indexing.py +++ b/test/unit/test_indexing.py @@ -56,37 +56,6 @@ def diagonal(tensor:Tensor) -> Tensor: def unravel_index(tensor, shape): pass -# https://github.com/pytorch/pytorch/blob/79811e765c23242210ebdc623539d2103a166463/torch/testing/_creation.py#L38 -def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor: - r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with - values uniformly drawn from ``[low, high)``. - - If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable - finite values then they are clamped to the lowest or highest representable finite value, respectively. - If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`, - which depend on :attr:`dtype`. - - +---------------------------+------------+----------+ - | ``dtype`` | ``low`` | ``high`` | - +===========================+============+==========+ - | boolean type | ``0`` | ``2`` | - +---------------------------+------------+----------+ - | unsigned integral type | ``0`` | ``10`` | - +---------------------------+------------+----------+ - | signed integral types | ``-9`` | ``10`` | - +---------------------------+------------+----------+ - | floating types | ``-9`` | ``9`` | - +---------------------------+------------+----------+ - | complex types | ``-9`` | ``9`` | - +---------------------------+------------+----------+ - """ - contiguous = not noncontiguous - if dtype == dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool) - elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype) - elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype) # signed int - elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous) - else: raise NotImplementedError(f"{dtype} not implemented") - class TestIndexing(unittest.TestCase): def test_index(self): @@ -711,17 +680,15 @@ class TestIndexing(unittest.TestCase): numpy_testing_assert_equal_helper(out, Tensor.zeros(2)) ''' - # TODO argsort - ''' - def test_take_along_dim_invalid(self): + def test_gather_invalid(self): for dtype in (dtypes.int64, dtypes.float32): shape = (2, 3, 1, 4) - dim = 0 - t = make_tensor(shape, dtype=dtype) - indices = argsort(t, dim=dim) + t = (Tensor.randint(*shape, low=-9, high=10, dtype=dtype) if dtypes.is_int(dtype) + else Tensor.uniform(*shape, low=-9.0, high=9.0, dtype=dtype)) + indices = t.argsort(dim=0) # dim of `t` and `indices` does not match - with self.assertRaises(RuntimeError, "input and indices should have the same number of dimensions"): + with self.assertRaises(RuntimeError): t.gather(0, indices[0]) # invalid `indices` dtype @@ -731,8 +698,9 @@ class TestIndexing(unittest.TestCase): with self.assertRaises(RuntimeError): t.gather(0, indices.cast(dtypes.float32)) - with self.assertRaises(RuntimeError): - t.gather(0, indices.cast(dtypes.int32)) + # torch requires int64 indices; tinygrad accepts any int dtype + # with self.assertRaises(RuntimeError): + # t.gather(0, indices.cast(dtypes.int32)) # invalid axis with self.assertRaises(IndexError): @@ -740,7 +708,6 @@ class TestIndexing(unittest.TestCase): with self.assertRaises(IndexError): t.gather(7, indices) - ''' class TestNumpy(unittest.TestCase): def test_empty_tuple_index(self): diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index f06c1d11ff..963a988853 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -923,7 +923,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): ``` """ if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}") - assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}" + if index.ndim != self.ndim: raise RuntimeError(f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}") dim = self._resolve_dim(dim) assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim" x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)