fix empty input for masked_select and nonzero (#16168)

This commit is contained in:
chenyu
2026-05-12 15:36:51 -04:00
committed by GitHub
parent da3b7e89a4
commit 7c3e3fa154
3 changed files with 4 additions and 2 deletions

View File

@@ -91,7 +91,6 @@ class TestEmptyTensorEdgeCases(unittest.TestCase):
with self.assertRaises(RuntimeError):
Tensor([]).argmax()
@unittest.expectedFailure
def test_masked_select_empty(self):
# Masked select on empty tensors should return an empty tensor.
torch_out = torch.tensor([], dtype=torch.float32).masked_select(torch.tensor([], dtype=torch.bool))

View File

@@ -3334,6 +3334,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(32, 10)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
helper_test_op([(20,)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
helper_test_op([(10, 5, 3)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True)
for v in (0, 1, 0.0, 2.5, True, False):
helper_test_op(None, lambda x: x.nonzero().int(), lambda x: x.nonzero(), vals=[v], forward_only=True)
def test_cast(self):
helper_test_op([(3, 3)], lambda x: x.float())

View File

@@ -1074,7 +1074,7 @@ class Tensor(OpMixin):
if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}")
x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten()
mask_cumsum = mask.cumsum()
counts = Tensor.zeros(mask_cumsum[-1].item(), dtype=dtypes.int32, device=self.device)
counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device)
idxs = counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum()
return x[idxs]
@@ -1099,6 +1099,7 @@ class Tensor(OpMixin):
print(t.nonzero().numpy())
```
"""
if self.ndim == 0: return Tensor.zeros(int((self != 0).item()), 0, dtype=dtypes.int32, device=self.device)
mask = (self != 0).flatten()
indices = Tensor.stack(*[Tensor.arange(s, device=self.device).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten()
for i, s in enumerate(self.shape)], dim=-1)