fix wrong kwargs passed into rands (#16149)

working towards explicit args for these
This commit is contained in:
chenyu
2026-05-11 22:22:06 -04:00
committed by GitHub
parent 039d84ff02
commit 3942a80f66
3 changed files with 11 additions and 43 deletions

View File

@@ -165,7 +165,8 @@ def isin_tensor_tensor_out(x, y, *, assume_unique=False, invert=False, out=None)
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
def randperm_generator(n, generator=None, out=None):
return out.copy_(wrap(Tensor.randperm(n, generator=generator, device=unwrap(out).device)))
if generator is not None: raise NotImplementedError("tinygrad torch backend does not support torch.Generator for randperm")
return out.copy_(wrap(Tensor.randperm(n, device=unwrap(out).device)))
@torch.library.impl("aten::_linalg_eigh", "privateuseone")
# TODO: move to tinygrad

View File

@@ -56,37 +56,6 @@ def diagonal(tensor:Tensor) -> Tensor:
def unravel_index(tensor, shape):
pass
# https://github.com/pytorch/pytorch/blob/79811e765c23242210ebdc623539d2103a166463/torch/testing/_creation.py#L38
def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor:
r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
values uniformly drawn from ``[low, high)``.
If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable
finite values then they are clamped to the lowest or highest representable finite value, respectively.
If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`,
which depend on :attr:`dtype`.
+---------------------------+------------+----------+
| ``dtype`` | ``low`` | ``high`` |
+===========================+============+==========+
| boolean type | ``0`` | ``2`` |
+---------------------------+------------+----------+
| unsigned integral type | ``0`` | ``10`` |
+---------------------------+------------+----------+
| signed integral types | ``-9`` | ``10`` |
+---------------------------+------------+----------+
| floating types | ``-9`` | ``9`` |
+---------------------------+------------+----------+
| complex types | ``-9`` | ``9`` |
+---------------------------+------------+----------+
"""
contiguous = not noncontiguous
if dtype == dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype)
elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype) # signed int
elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
else: raise NotImplementedError(f"{dtype} not implemented")
class TestIndexing(unittest.TestCase):
def test_index(self):
@@ -711,17 +680,15 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(out, Tensor.zeros(2))
'''
# TODO argsort
'''
def test_take_along_dim_invalid(self):
def test_gather_invalid(self):
for dtype in (dtypes.int64, dtypes.float32):
shape = (2, 3, 1, 4)
dim = 0
t = make_tensor(shape, dtype=dtype)
indices = argsort(t, dim=dim)
t = (Tensor.randint(*shape, low=-9, high=10, dtype=dtype) if dtypes.is_int(dtype)
else Tensor.uniform(*shape, low=-9.0, high=9.0, dtype=dtype))
indices = t.argsort(dim=0)
# dim of `t` and `indices` does not match
with self.assertRaises(RuntimeError, "input and indices should have the same number of dimensions"):
with self.assertRaises(RuntimeError):
t.gather(0, indices[0])
# invalid `indices` dtype
@@ -731,8 +698,9 @@ class TestIndexing(unittest.TestCase):
with self.assertRaises(RuntimeError):
t.gather(0, indices.cast(dtypes.float32))
with self.assertRaises(RuntimeError):
t.gather(0, indices.cast(dtypes.int32))
# torch requires int64 indices; tinygrad accepts any int dtype
# with self.assertRaises(RuntimeError):
# t.gather(0, indices.cast(dtypes.int32))
# invalid axis
with self.assertRaises(IndexError):
@@ -740,7 +708,6 @@ class TestIndexing(unittest.TestCase):
with self.assertRaises(IndexError):
t.gather(7, indices)
'''
class TestNumpy(unittest.TestCase):
def test_empty_tuple_index(self):

View File

@@ -923,7 +923,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
```
"""
if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
if index.ndim != self.ndim: raise RuntimeError(f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}")
dim = self._resolve_dim(dim)
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)