mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
scatter to mixin (#15917)
This commit is contained in:
@@ -164,6 +164,32 @@ class TestTensorUOpLoss(unittest.TestCase):
|
||||
self.assertIs(_strip_unique(t.sparse_categorical_crossentropy(Y, ignore_index=0).uop),
|
||||
_strip_unique(t.uop.sparse_categorical_crossentropy(Y.uop, ignore_index=0)))
|
||||
|
||||
class TestTensorUOpScatter(unittest.TestCase):
  """Checks that `scatter` on a Tensor and `scatter` on its uops give the identical uop after `_strip_unique`."""

  def test_scatter(self):
    x = _t(3, 4).float()
    idx = Tensor([[0, 1, 2, 0]], dtype=dtypes.int32)
    src = _t(1, 4).float()
    via_tensor = _strip_unique(x.scatter(0, idx, src).uop)
    via_uop = _strip_unique(x.uop.scatter(0, idx.uop, src.uop))
    self.assertIs(via_tensor, via_uop)

  def test_scatter_scalar_src(self):
    x = _t(3, 4).float()
    idx = Tensor([[0, 1]], dtype=dtypes.int32)
    via_tensor = _strip_unique(x.scatter(1, idx, 3.14).uop)
    via_uop = _strip_unique(x.uop.scatter(1, idx.uop, 3.14))
    self.assertIs(via_tensor, via_uop)

  # inf cannot be cast to int, so this regresses if a scalar src is routed through index.dtype first
  def test_scatter_inf_src(self):
    x = _t(3, 4).float()
    idx = Tensor([[0, 1]], dtype=dtypes.int32)
    via_tensor = _strip_unique(x.scatter(1, idx, float("inf")).uop)
    via_uop = _strip_unique(x.uop.scatter(1, idx.uop, float("inf")))
    self.assertIs(via_tensor, via_uop)

  def test_scatter_add(self):
    x = _t(3, 4).float()
    idx = Tensor([[0, 1]], dtype=dtypes.int32)
    via_tensor = _strip_unique(x.scatter(1, idx, 3.14, reduce="add").uop)
    via_uop = _strip_unique(x.uop.scatter(1, idx.uop, 3.14, reduce="add"))
    self.assertIs(via_tensor, via_uop)

  def test_scatter_multiply(self):
    x = _t(3, 4).float()
    idx = Tensor([[0, 1]], dtype=dtypes.int32)
    via_tensor = _strip_unique(x.scatter(1, idx, 3.14, reduce="multiply").uop)
    via_uop = _strip_unique(x.uop.scatter(1, idx.uop, 3.14, reduce="multiply"))
    self.assertIs(via_tensor, via_uop)

  # a tensor src combined with reduce hits the "elif reduce: raise" branch in both the Tensor and UOp paths
  def test_scatter_tensor_src_with_reduce_raises(self):
    x = _t(3, 4).float()
    idx = Tensor([[0, 1]], dtype=dtypes.int32)
    src = _t(1, 2).float()
    with self.assertRaises(TypeError):
      x.scatter(1, idx, src, reduce="add")
    with self.assertRaises(TypeError):
      x.uop.scatter(1, idx.uop, src.uop, reduce="add")
|
||||
|
||||
class TestTensorUOpScatterReduce(unittest.TestCase):
|
||||
def _check(self, x, idx, src, **kw):
|
||||
self.assertIs(_strip_unique(x.scatter_reduce(0, idx, src, **kw).uop),
|
||||
|
||||
@@ -921,6 +921,49 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)
|
||||
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
|
||||
|
||||
def scatter(self, dim:int, index:Self, src:Self|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Self:
  """
  Scatters `src` values along an axis specified by `dim`.
  Apply `add` or `multiply` reduction operation with `reduce`.

  A Python scalar `src` (`int`, `float` or `bool`) is expanded to `index.shape` using the dtype and device of `self`.
  A `TypeError` is raised when `reduce` is not `None`, `'add'` or `'multiply'`, or when `reduce` is combined with a non-scalar `src`.

  NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.

  ```python exec="true" source="above" session="tensor" result="python"
  src = Tensor.arange(1, 11).reshape(2, 5)
  print(src.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  index = Tensor([[0, 1, 2, 0]])
  print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  index = Tensor([[0, 1, 2], [0, 1, 4]])
  print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
  ```
  """
  if reduce not in {None, "add", "multiply"}:
    raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
  if isinstance(src, (int, float, bool)):
    # the scalar is materialized directly in self.dtype (never via index.dtype)
    src = type(self).full(index.shape, src, dtype=self.dtype, device=self.device)
  elif reduce is not None:
    raise TypeError("non-scalar src is not supported with reduce arg. use scatter_reduce")
  if reduce is None:
    # plain scatter: overwrite the selected positions of self
    scattered, hit = self._pre_scatter(dim, index, src)
    return self._masked_merge(scattered, hit, (-1,))
  # reducing scatter: delegate to scatter_reduce, always including self in the reduction
  reduce_op:Literal["sum", "prod"] = "sum" if reduce == "add" else "prod"
  return self.scatter_reduce(dim, index, src, reduce_op, include_self=True)
|
||||
|
||||
def _masked_merge(self, values:Self, mask:Self, axes:tuple[int, ...]) -> Self:
  """
  Returns `self` with entries replaced by `values` wherever `mask` is true.

  `values` and `mask` carry extra dims listed in `axes`; each one is folded away slice by slice
  (later slices override earlier ones where their mask is true) and then squeezed out.
  """
  def _later_wins(acc, cur):
    # acc and cur are (mask_slice, values_slice) pairs; a true entry in cur replaces what was accumulated
    acc_mask, acc_values = acc
    cur_mask, cur_values = cur
    return acc_mask | cur_mask, cur_mask.where(cur_values, acc_values)

  # reduce such that if mask contains repeated indices the last one remains
  for ax in reversed(axes):
    size_one_slices = zip(mask.split(1, ax), values.split(1, ax))
    mask, values = functools.reduce(_later_wins, size_one_slices)
  # the folded dims are left with size 1, drop them
  for ax in reversed(axes):
    mask = mask.squeeze(ax)
    values = values.squeeze(ax)
  # take values for each True element in mask, otherwise keep self
  return mask.where(values, self)
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
def sequential(self, ll:list[Callable[[Self], Self]]) -> Self:
|
||||
|
||||
@@ -77,15 +77,6 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
|
||||
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
|
||||
return ret
|
||||
|
||||
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
  """
  Returns `target` with entries replaced by `values` wherever `mask` is true.

  `values` and `mask` carry extra dims listed in `axes`; each one is folded away slice by slice
  (later slices override earlier ones where their mask is true) and then squeezed out.
  """
  def _later_wins(acc, cur):
    # acc and cur are (mask_slice, values_slice) pairs; a true entry in cur replaces what was accumulated
    acc_mask, acc_values = acc
    cur_mask, cur_values = cur
    return acc_mask | cur_mask, cur_mask.where(cur_values, acc_values)

  # reduce such that if mask contains repeated indices the last one remains
  for ax in reversed(axes):
    size_one_slices = zip(mask.split(1, ax), values.split(1, ax))
    mask, values = functools.reduce(_later_wins, size_one_slices)
  # the folded dims are left with size 1, drop them
  for ax in reversed(axes):
    mask = mask.squeeze(ax)
    values = values.squeeze(ax)
  # take values for each True element in mask, otherwise keep target
  return mask.where(values, target)
|
||||
|
||||
class Tensor(OpMixin):
|
||||
"""
|
||||
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
||||
@@ -1063,7 +1054,7 @@ class Tensor(OpMixin):
|
||||
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
|
||||
for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum
|
||||
start = dims[0] if not permuted else 0
|
||||
vb = _masked_setitem(x_pre, vb, mask, tuple(range(start, start + len(big_shape))))
|
||||
vb = x_pre._masked_merge(vb, mask, tuple(range(start, start + len(big_shape))))
|
||||
elif v is None: return x # basic getitem
|
||||
# basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims)
|
||||
else: vb = v.cast(self.dtype)._broadcast_to(x.shape)
|
||||
@@ -1365,40 +1356,6 @@ class Tensor(OpMixin):
|
||||
if IMAGE: return self.image_dot(w, dtype)
|
||||
return super().dot(w, dtype)
|
||||
|
||||
def scatter(self, dim:int, index:Tensor, src:Tensor|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
  """
  Scatters `src` values along an axis specified by `dim`.
  Apply `add` or `multiply` reduction operation with `reduce`.

  A non-Tensor `src` is turned into a Tensor with `index.full_like`, using the dtype and device of `self`.
  A `TypeError` is raised when `reduce` is not `None`, `'add'` or `'multiply'`, or when `reduce` is combined with a Tensor `src`.

  NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.

  ```python exec="true" source="above" session="tensor" result="python"
  src = Tensor.arange(1, 11).reshape(2, 5)
  print(src.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  index = Tensor([[0, 1, 2, 0]])
  print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  index = Tensor([[0, 1, 2], [0, 1, 4]])
  print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
  ```
  """
  if reduce not in {None, "add", "multiply"}:
    raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
  if isinstance(src, Tensor):
    if reduce is not None:
      raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce")
  else:
    # fill in the dtype and device of self, laid out like index
    src = index.full_like(src, device=self.device, dtype=self.dtype)
  if reduce is None:
    # plain scatter: overwrite the selected positions of self
    scattered, hit = self._pre_scatter(dim, index, src)
    return _masked_setitem(self, scattered, hit, (-1,))
  # reducing scatter: delegate to scatter_reduce, always including self in the reduction
  reduce_op:Literal["sum", "prod"] = "sum" if reduce == "add" else "prod"
  return self.scatter_reduce(dim, index, src, reduce_op, include_self=True)
|
||||
|
||||
# ***** unary ops *****
|
||||
|
||||
def contiguous(self, *args, **kwargs) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user