mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix empty input for masked_select and nonzero (#16168)
This commit is contained in:
@@ -91,7 +91,6 @@ class TestEmptyTensorEdgeCases(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
Tensor([]).argmax()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_masked_select_empty(self):
|
||||
# Masked select on empty tensors should return an empty tensor.
|
||||
torch_out = torch.tensor([], dtype=torch.float32).masked_select(torch.tensor([], dtype=torch.bool))
|
||||
|
||||
@@ -3334,6 +3334,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(32, 10)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
|
||||
helper_test_op([(20,)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
|
||||
helper_test_op([(10, 5, 3)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
|
||||
for v in (0, 1, 0.0, 2.5, True, False):
|
||||
helper_test_op(None, lambda x: x.nonzero().int(), lambda x: x.nonzero(), vals=[v], forward_only=True)
|
||||
|
||||
def test_cast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.float())
|
||||
|
||||
@@ -1074,7 +1074,7 @@ class Tensor(OpMixin):
|
||||
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
|
||||
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
|
||||
mask_cumsum = mask.cumsum()
|
||||
counts = Tensor.zeros(mask_cumsum[-1].item(), dtype=dtypes.int32, device=self.device)
|
||||
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device)
|
||||
idxs = counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()
|
||||
return x[idxs]
|
||||
|
||||
@@ -1099,6 +1099,7 @@ class Tensor(OpMixin):
|
||||
print(t.nonzero().numpy())
|
||||
```
|
||||
"""
|
||||
if self.ndim == 0: return Tensor.zeros(int((self != 0).item()), 0, dtype=dtypes.int32, device=self.device)
|
||||
mask = (self != 0).flatten()
|
||||
indices = Tensor.stack(*[Tensor.arange(s, device=self.device).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
|
||||
for i, s in enumerate(self.shape)], dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user