mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
gather to mixin (#15891)
This commit is contained in:
@@ -99,6 +99,13 @@ class TestTensorUOpOneHot(unittest.TestCase):
|
||||
t = _t(5)
|
||||
self.assertIs(_strip_unique(t.one_hot(5).uop), _strip_unique(t.uop.one_hot(5)))
|
||||
|
||||
class TestTensorUOpGather(unittest.TestCase):
|
||||
def _check(self, t, dim, idx):
|
||||
self.assertIs(_strip_unique(t.gather(dim, idx).uop), _strip_unique(t.uop.gather(dim, idx.uop)))
|
||||
def test_gather_1d(self): self._check(_t(5), 0, Tensor([2, 1, 0, 1, 2], dtype=dtypes.int32))
|
||||
def test_gather_dim0(self): self._check(_t(3, 4), 0, Tensor([[0, 1, 2, 0], [1, 2, 0, 1], [2, 0, 1, 2]], dtype=dtypes.int32))
|
||||
def test_gather_dim1(self): self._check(_t(3, 4), 1, Tensor([[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1]], dtype=dtypes.int32))
|
||||
|
||||
class TestTensorUOpCat(unittest.TestCase):
|
||||
def test_cat_dim0(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=0))
|
||||
def test_cat_dim1(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=1))
|
||||
|
||||
@@ -721,6 +721,25 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
if num_classes < 0: raise ValueError(f"num_classes must be non-negative, got {num_classes}")
|
||||
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
|
||||
|
||||
def gather(self, dim:int, index:Self) -> Self:
|
||||
"""
|
||||
Gathers values along an axis specified by `dim`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 2], [3, 4]])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
|
||||
```
|
||||
"""
|
||||
if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
|
||||
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
|
||||
dim = self._resolve_dim(dim)
|
||||
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
||||
x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
||||
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:
|
||||
|
||||
@@ -1154,25 +1154,6 @@ class Tensor(OpMixin):
|
||||
def __delitem__(self, indices) -> None:
|
||||
raise TypeError("Tensor does not support deleting items")
|
||||
|
||||
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
|
||||
"""
|
||||
Gathers values along an axis specified by `dim`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 2], [3, 4]])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
|
||||
```
|
||||
"""
|
||||
if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
|
||||
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
|
||||
dim = self._resolve_dim(dim)
|
||||
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
||||
x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
||||
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
|
||||
|
||||
def masked_select(self, mask):
|
||||
"""
|
||||
Selects elements from `self` based on the boolean `mask`.
|
||||
|
||||
Reference in New Issue
Block a user