From f0dbc68aa9a331a1fe7afa1392ee7d8dd0928eb4 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 23 Apr 2026 14:00:57 -0400 Subject: [PATCH] gather to mixin (#15891) --- test/null/test_tensor_uop_mixin.py | 7 +++++++ tinygrad/mixin/__init__.py | 19 +++++++++++++++++++ tinygrad/tensor.py | 19 ------------------- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py index 7961c95b46..5ee23f4802 100644 --- a/test/null/test_tensor_uop_mixin.py +++ b/test/null/test_tensor_uop_mixin.py @@ -99,6 +99,13 @@ class TestTensorUOpOneHot(unittest.TestCase): t = _t(5) self.assertIs(_strip_unique(t.one_hot(5).uop), _strip_unique(t.uop.one_hot(5))) +class TestTensorUOpGather(unittest.TestCase): + def _check(self, t, dim, idx): + self.assertIs(_strip_unique(t.gather(dim, idx).uop), _strip_unique(t.uop.gather(dim, idx.uop))) + def test_gather_1d(self): self._check(_t(5), 0, Tensor([2, 1, 0, 1, 2], dtype=dtypes.int32)) + def test_gather_dim0(self): self._check(_t(3, 4), 0, Tensor([[0, 1, 2, 0], [1, 2, 0, 1], [2, 0, 1, 2]], dtype=dtypes.int32)) + def test_gather_dim1(self): self._check(_t(3, 4), 1, Tensor([[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1]], dtype=dtypes.int32)) + class TestTensorUOpCat(unittest.TestCase): def test_cat_dim0(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=0)) def test_cat_dim1(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=1)) diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index c6155c56b4..01a2bbef67 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -721,6 +721,25 @@ class OpMixin(ElementwiseMixin, ReduceMixin): if num_classes < 0: raise ValueError(f"num_classes must be non-negative, got {num_classes}") return self[..., None]._one_hot_along_dim(num_classes).where(1, 0) + def gather(self, dim:int, index:Self) -> Self: + """ + Gathers values along an axis specified by `dim`. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([[1, 2], [3, 4]]) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy()) + ``` + """ + if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}") + assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}" + dim = self._resolve_dim(dim) + assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim" + x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim) + return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype) + # ***** functional nn ops ***** def sequential(self, ll:list[Callable[[Self], Self]]) -> Self: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index db195f0b1f..c8d3a7455e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1154,25 +1154,6 @@ class Tensor(OpMixin): def __delitem__(self, indices) -> None: raise TypeError("Tensor does not support deleting items") - def gather(self:Tensor, dim:int, index:Tensor) -> Tensor: - """ - Gathers values along an axis specified by `dim`. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([[1, 2], [3, 4]]) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy()) - ``` - """ - if index.device != self.device: raise RuntimeError(f"expected index and self on the same device, {index.device=}, {self.device=}") - assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}" - dim = self._resolve_dim(dim) - assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim" - x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim) - return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype) - def masked_select(self, mask): """ Selects elements from `self` based on the boolean `mask`.