From 2f9fdb4a375d3c9efa271acab6b6592bfee628c7 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 24 Apr 2026 15:37:37 -0400 Subject: [PATCH] scatter to mixin (#15917) --- test/null/test_tensor_uop_mixin.py | 26 +++++++++++++++++ tinygrad/mixin/__init__.py | 43 ++++++++++++++++++++++++++++ tinygrad/tensor.py | 45 +----------------------------- 3 files changed, 70 insertions(+), 44 deletions(-) diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py index c917573d87..844260a64b 100644 --- a/test/null/test_tensor_uop_mixin.py +++ b/test/null/test_tensor_uop_mixin.py @@ -164,6 +164,32 @@ class TestTensorUOpLoss(unittest.TestCase): self.assertIs(_strip_unique(t.sparse_categorical_crossentropy(Y, ignore_index=0).uop), _strip_unique(t.uop.sparse_categorical_crossentropy(Y.uop, ignore_index=0))) +class TestTensorUOpScatter(unittest.TestCase): + def test_scatter(self): + x, idx, src = _t(3, 4).float(), Tensor([[0, 1, 2, 0]], dtype=dtypes.int32), _t(1, 4).float() + self.assertIs(_strip_unique(x.scatter(0, idx, src).uop), _strip_unique(x.uop.scatter(0, idx.uop, src.uop))) + def test_scatter_scalar_src(self): + x, idx = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32) + self.assertIs(_strip_unique(x.scatter(1, idx, 3.14).uop), _strip_unique(x.uop.scatter(1, idx.uop, 3.14))) + # inf cannot be cast to int — this regresses if scalar src is routed through index.dtype first + def test_scatter_inf_src(self): + x, idx = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32) + self.assertIs(_strip_unique(x.scatter(1, idx, float("inf")).uop), + _strip_unique(x.uop.scatter(1, idx.uop, float("inf")))) + def test_scatter_add(self): + x, idx = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32) + self.assertIs(_strip_unique(x.scatter(1, idx, 3.14, reduce="add").uop), + _strip_unique(x.uop.scatter(1, idx.uop, 3.14, reduce="add"))) + def test_scatter_multiply(self): + x, idx = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32) + self.assertIs(_strip_unique(x.scatter(1, idx, 3.14, reduce="multiply").uop), + _strip_unique(x.uop.scatter(1, idx.uop, 3.14, reduce="multiply"))) + # tensor src with reduce hits the "elif reduce: raise" branch in both Tensor and UOp paths + def test_scatter_tensor_src_with_reduce_raises(self): + x, idx, src = _t(3, 4).float(), Tensor([[0, 1]], dtype=dtypes.int32), _t(1, 2).float() + with self.assertRaises(TypeError): x.scatter(1, idx, src, reduce="add") + with self.assertRaises(TypeError): x.uop.scatter(1, idx.uop, src.uop, reduce="add") + class TestTensorUOpScatterReduce(unittest.TestCase): def _check(self, x, idx, src, **kw): self.assertIs(_strip_unique(x.scatter_reduce(0, idx, src, **kw).uop), diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index 1ff335d5d3..0e408bf501 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -921,6 +921,49 @@ class OpMixin(ElementwiseMixin, ReduceMixin): return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count) raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'") + def scatter(self, dim:int, index:Self, src:Self|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Self: + """ + Scatters `src` values along an axis specified by `dim`. + Apply `add` or `multiply` reduction operation with `reduce`. + + NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`. + + ```python exec="true" source="above" session="tensor" result="python" + src = Tensor.arange(1, 11).reshape(2, 5) + print(src.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + index = Tensor([[0, 1, 2, 0]]) + print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + index = Tensor([[0, 1, 2], [0, 1, 4]]) + print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy()) + ``` + """ + if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'") + if isinstance(src, (int, float, bool)): src = type(self).full(index.shape, src, dtype=self.dtype, device=self.device) + elif reduce: raise TypeError("non-scalar src is not supported with reduce arg. use scatter_reduce") + if reduce == "add": return self.scatter_reduce(dim, index, src, "sum", include_self=True) + if reduce == "multiply": return self.scatter_reduce(dim, index, src, "prod", include_self=True) + src, mask = self._pre_scatter(dim, index, src) + return self._masked_merge(src, mask, (-1,)) + + def _masked_merge(self, values:Self, mask:Self, axes:tuple[int, ...]) -> Self: + # reduce such that if mask contains repeated indices the last one remains + for dim in reversed(axes): + mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim))) + # remove extra dims from reduce + for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim) + # select from values for each True element in mask else select from self + return mask.where(values, self) + # ***** functional nn ops ***** def sequential(self, ll:list[Callable[[Self], Self]]) -> Self: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 3ce56a8dca..861fe928a9 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -77,15 +77,6 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor: assert isinstance(ret, Tensor), "sum didn't return a Tensor" return ret -def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor: - # reduce such that if mask contains repeated indices the last one remains - for dim in reversed(axes): - mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim))) - # remove extra dims from reduce - for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim) - # select from values for each True element in mask else select from target - return mask.where(values, target) - class Tensor(OpMixin): """ A `Tensor` is a multi-dimensional matrix containing elements of a single data type. @@ -1063,7 +1054,7 @@ class Tensor(OpMixin): vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape)) for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum start = dims[0] if not permuted else 0 - vb = _masked_setitem(x_pre, vb, mask, tuple(range(start, start + len(big_shape)))) + vb = x_pre._masked_merge(vb, mask, tuple(range(start, start + len(big_shape)))) elif v is None: return x # basic getitem # basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims) else: vb = v.cast(self.dtype)._broadcast_to(x.shape) @@ -1365,40 +1356,6 @@ class Tensor(OpMixin): if IMAGE: return self.image_dot(w, dtype) return super().dot(w, dtype) - def scatter(self, dim:int, index:Tensor, src:Tensor|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Tensor: - """ - Scatters `src` values along an axis specified by `dim`. - Apply `add` or `multiply` reduction operation with `reduce`. - - NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`. - - ```python exec="true" source="above" session="tensor" result="python" - src = Tensor.arange(1, 11).reshape(2, 5) - print(src.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - index = Tensor([[0, 1, 2, 0]]) - print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - index = Tensor([[0, 1, 2], [0, 1, 4]]) - print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy()) - ``` - """ - if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'") - if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce") - if not isinstance(src, Tensor): src = index.full_like(src, device=self.device, dtype=self.dtype) - if reduce == "add": return self.scatter_reduce(dim, index, src, "sum", include_self=True) - if reduce == "multiply": return self.scatter_reduce(dim, index, src, "prod", include_self=True) - src, mask = self._pre_scatter(dim, index, src) - return _masked_setitem(self, src, mask, (-1,)) - # ***** unary ops ***** def contiguous(self, *args, **kwargs) -> Tensor: