mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
fix wrong kwargs passed into rands (#16149)
working towards explicit args for these
This commit is contained in:
@@ -165,7 +165,8 @@ def isin_tensor_tensor_out(x, y, *, assume_unique=False, invert=False, out=None)
|
||||
|
||||
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
|
||||
def randperm_generator(n, generator=None, out=None):
|
||||
return out.copy_(wrap(Tensor.randperm(n, generator=generator, device=unwrap(out).device)))
|
||||
if generator is not None: raise NotImplementedError("tinygrad torch backend does not support torch.Generator for randperm")
|
||||
return out.copy_(wrap(Tensor.randperm(n, device=unwrap(out).device)))
|
||||
|
||||
@torch.library.impl("aten::_linalg_eigh", "privateuseone")
|
||||
# TODO: move to tinygrad
|
||||
|
||||
@@ -56,37 +56,6 @@ def diagonal(tensor:Tensor) -> Tensor:
|
||||
def unravel_index(tensor, shape):
|
||||
pass
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/79811e765c23242210ebdc623539d2103a166463/torch/testing/_creation.py#L38
|
||||
def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor:
|
||||
r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
|
||||
values uniformly drawn from ``[low, high)``.
|
||||
|
||||
If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable
|
||||
finite values then they are clamped to the lowest or highest representable finite value, respectively.
|
||||
If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`,
|
||||
which depend on :attr:`dtype`.
|
||||
|
||||
+---------------------------+------------+----------+
|
||||
| ``dtype`` | ``low`` | ``high`` |
|
||||
+===========================+============+==========+
|
||||
| boolean type | ``0`` | ``2`` |
|
||||
+---------------------------+------------+----------+
|
||||
| unsigned integral type | ``0`` | ``10`` |
|
||||
+---------------------------+------------+----------+
|
||||
| signed integral types | ``-9`` | ``10`` |
|
||||
+---------------------------+------------+----------+
|
||||
| floating types | ``-9`` | ``9`` |
|
||||
+---------------------------+------------+----------+
|
||||
| complex types | ``-9`` | ``9`` |
|
||||
+---------------------------+------------+----------+
|
||||
"""
|
||||
contiguous = not noncontiguous
|
||||
if dtype == dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
|
||||
elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype)
|
||||
elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype) # signed int
|
||||
elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
|
||||
else: raise NotImplementedError(f"{dtype} not implemented")
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def test_index(self):
|
||||
|
||||
@@ -711,17 +680,15 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(out, Tensor.zeros(2))
|
||||
'''
|
||||
|
||||
# TODO argsort
|
||||
'''
|
||||
def test_take_along_dim_invalid(self):
|
||||
def test_gather_invalid(self):
|
||||
for dtype in (dtypes.int64, dtypes.float32):
|
||||
shape = (2, 3, 1, 4)
|
||||
dim = 0
|
||||
t = make_tensor(shape, dtype=dtype)
|
||||
indices = argsort(t, dim=dim)
|
||||
t = (Tensor.randint(*shape, low=-9, high=10, dtype=dtype) if dtypes.is_int(dtype)
|
||||
else Tensor.uniform(*shape, low=-9.0, high=9.0, dtype=dtype))
|
||||
indices = t.argsort(dim=0)
|
||||
|
||||
# dim of `t` and `indices` does not match
|
||||
with self.assertRaises(RuntimeError, "input and indices should have the same number of dimensions"):
|
||||
with self.assertRaises(RuntimeError):
|
||||
t.gather(0, indices[0])
|
||||
|
||||
# invalid `indices` dtype
|
||||
@@ -731,8 +698,9 @@ class TestIndexing(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
t.gather(0, indices.cast(dtypes.float32))
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
t.gather(0, indices.cast(dtypes.int32))
|
||||
# torch requires int64 indices; tinygrad accepts any int dtype
|
||||
# with self.assertRaises(RuntimeError):
|
||||
# t.gather(0, indices.cast(dtypes.int32))
|
||||
|
||||
# invalid axis
|
||||
with self.assertRaises(IndexError):
|
||||
@@ -740,7 +708,6 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
with self.assertRaises(IndexError):
|
||||
t.gather(7, indices)
|
||||
'''
|
||||
|
||||
class TestNumpy(unittest.TestCase):
|
||||
def test_empty_tuple_index(self):
|
||||
|
||||
@@ -923,7 +923,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
```
|
||||
"""
|
||||
if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
|
||||
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
|
||||
if index.ndim != self.ndim: raise RuntimeError(f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}")
|
||||
dim = self._resolve_dim(dim)
|
||||
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
||||
x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
||||
|
||||
Reference in New Issue
Block a user