gather to mixin (#15891)

This commit is contained in:
chenyu
2026-04-23 14:00:57 -04:00
committed by GitHub
parent 87223f870e
commit f0dbc68aa9
3 changed files with 26 additions and 19 deletions

View File

@@ -99,6 +99,13 @@ class TestTensorUOpOneHot(unittest.TestCase):
t = _t(5)
self.assertIs(_strip_unique(t.one_hot(5).uop), _strip_unique(t.uop.one_hot(5)))
class TestTensorUOpGather(unittest.TestCase):
def _check(self, t, dim, idx):
self.assertIs(_strip_unique(t.gather(dim, idx).uop), _strip_unique(t.uop.gather(dim, idx.uop)))
def test_gather_1d(self): self._check(_t(5), 0, Tensor([2, 1, 0, 1, 2], dtype=dtypes.int32))
def test_gather_dim0(self): self._check(_t(3, 4), 0, Tensor([[0, 1, 2, 0], [1, 2, 0, 1], [2, 0, 1, 2]], dtype=dtypes.int32))
def test_gather_dim1(self): self._check(_t(3, 4), 1, Tensor([[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1]], dtype=dtypes.int32))
class TestTensorUOpCat(unittest.TestCase):
def test_cat_dim0(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=0))
def test_cat_dim1(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=1))

View File

@@ -721,6 +721,25 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
if num_classes < 0: raise ValueError(f"num_classes must be non-negative, got {num_classes}")
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
def gather(self, dim:int, index:Self) -> Self:
"""
Gathers values along an axis specified by `dim`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2], [3, 4]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
```
"""
if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
dim = self._resolve_dim(dim)
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
# ***** functional nn ops *****
def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:

View File

@@ -1154,25 +1154,6 @@ class Tensor(OpMixin):
def __delitem__(self, indices) -> None:
raise TypeError("Tensor does not support deleting items")
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
"""
Gathers values along an axis specified by `dim`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2], [3, 4]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
```
"""
if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}")
assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
dim = self._resolve_dim(dim)
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
def masked_select(self, mask):
"""
Selects elements from `self` based on the boolean `mask`.