From 7c3e3fa1542887f336fd0a9cc62efd6ef507cbcf Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 12 May 2026 15:36:51 -0400 Subject: [PATCH] fix empty input for masked_select and nonzero (#16168) --- test/backend/test_edgecases.py | 1 - test/backend/test_ops.py | 2 ++ tinygrad/tensor.py | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/backend/test_edgecases.py b/test/backend/test_edgecases.py index 582b1ab7b6..53b56392b1 100644 --- a/test/backend/test_edgecases.py +++ b/test/backend/test_edgecases.py @@ -91,7 +91,6 @@ class TestEmptyTensorEdgeCases(unittest.TestCase): with self.assertRaises(RuntimeError): Tensor([]).argmax() - @unittest.expectedFailure def test_masked_select_empty(self): # Masked select on empty tensors should return an empty tensor. torch_out = torch.tensor([], dtype=torch.float32).masked_select(torch.tensor([], dtype=torch.bool)) diff --git a/test/backend/test_ops.py b/test/backend/test_ops.py index 93b71efadc..2ac363c878 100644 --- a/test/backend/test_ops.py +++ b/test/backend/test_ops.py @@ -3334,6 +3334,8 @@ class TestOps(unittest.TestCase): helper_test_op([(32, 10)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True) helper_test_op([(20,)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True) helper_test_op([(10, 5, 3)], lambda x: (x>0.5).nonzero().int(), lambda x: (x>0.5).nonzero(), forward_only=True) + for v in (0, 1, 0.0, 2.5, True, False): + helper_test_op(None, lambda x: x.nonzero().int(), lambda x: x.nonzero(), vals=[v], forward_only=True) def test_cast(self): helper_test_op([(3, 3)], lambda x: x.float()) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 83cb35a3db..21cd8d4745 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1074,7 +1074,7 @@ class Tensor(OpMixin): if not dtypes.is_bool(mask.dtype): raise RuntimeError(f"masked_select expects bool mask tensor, got {mask.dtype}") x, mask = self.flatten(), mask._broadcast_to(self.shape).flatten() mask_cumsum = mask.cumsum() - counts = Tensor.zeros(mask_cumsum[-1].item(), dtype=dtypes.int32, device=self.device) + counts = Tensor.zeros(mask_cumsum[-1].item() if mask.numel() else 0, dtype=dtypes.int32, device=self.device) idxs = counts.scatter(0, mask_cumsum, 1, reduce='add').cumsum() return x[idxs] @@ -1099,6 +1099,7 @@ class Tensor(OpMixin): print(t.nonzero().numpy()) ``` """ + if self.ndim == 0: return Tensor.zeros(int((self != 0).item()), 0, dtype=dtypes.int32, device=self.device) mask = (self != 0).flatten() indices = Tensor.stack(*[Tensor.arange(s, device=self.device).reshape(*[1]*i, s, *[1]*(self.ndim-i-1)).expand(self.shape).flatten() for i, s in enumerate(self.shape)], dim=-1)