mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-16 01:48:24 +08:00
add masked_select to tensor.py (#9468)
* add masked_select to tensor.py * fix tests --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -64,11 +64,6 @@ def inplace_fn(outvars: str|list[str]):
|
||||
|
||||
# *** bad functions on CPU ***
|
||||
|
||||
@torch.library.impl("aten::masked_select", "privateuseone")
|
||||
def masked_select(self, mask):
|
||||
# err, bad
|
||||
return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()], device=_from_torch_device(self.device)))
|
||||
|
||||
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
|
||||
@inplace_fn("self")
|
||||
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
|
||||
@@ -418,6 +413,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
"aten.masked_fill_.Scalar": inplace_fn("self")(lambda self, mask, value: self.assign(self.masked_fill(mask, value))),
|
||||
"aten.masked_fill.Scalar": Tensor.masked_fill,
|
||||
"aten.masked_fill.Tensor": Tensor.masked_fill,
|
||||
"aten.masked_select": Tensor.masked_select,
|
||||
"aten.all": Tensor.all,
|
||||
"aten.sgn": Tensor.sign,
|
||||
"aten.acos": Tensor.acos,
|
||||
|
||||
@@ -98,6 +98,15 @@ class TestTorchBackend(unittest.TestCase):
|
||||
np.testing.assert_equal(out.values.cpu().numpy(), [4, 3])
|
||||
np.testing.assert_equal(out.indices.cpu().numpy(), [3, 1])
|
||||
|
||||
def test_masked_select(self):
|
||||
a = torch.tensor([4, 3, 2, 1], device=device)
|
||||
mask = torch.tensor([True, False, True, False], device=device)
|
||||
out = torch.masked_select(a, mask)
|
||||
np.testing.assert_equal(out.cpu().numpy(), [4, 2])
|
||||
mask = torch.tensor(True, device=device)
|
||||
out = torch.masked_select(a, mask)
|
||||
np.testing.assert_equal(out.cpu().numpy(), [4, 3, 2, 1])
|
||||
|
||||
@unittest.skip("meh")
|
||||
def test_str(self):
|
||||
a = torch.ones(4, device=device)
|
||||
|
||||
@@ -2883,6 +2883,10 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
|
||||
|
||||
def test_masked_select(self):
|
||||
helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True)
|
||||
helper_test_op([(32, 10)], lambda x: x.masked_select(torch.tensor(True)), lambda x: x.masked_select(Tensor(True)), forward_only=True)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
|
||||
def test_cast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.float())
|
||||
|
||||
@@ -1556,6 +1556,27 @@ class Tensor(SimpleMathTrait):
|
||||
t = t.permute([lhs.index(name) for name in rhs])
|
||||
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
|
||||
|
||||
def masked_select(self, mask):
|
||||
"""
|
||||
Selects elements from `self` based on the boolean `mask`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
|
||||
mask = Tensor([[True, False, True], [False, True, False], [False, False, True]])
|
||||
print(t.numpy())
|
||||
print(mask.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.masked_select(mask).numpy())
|
||||
```
|
||||
"""
|
||||
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
|
||||
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
||||
mask_cumsum = mask.cumsum()
|
||||
counts = Tensor.zeros(mask_cumsum[-1].item(), dtype=dtypes.int32)
|
||||
idxs = counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()
|
||||
return x[idxs]
|
||||
|
||||
# ***** reduce ops *****
|
||||
|
||||
def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user