diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 2bbc87b1d4..f6b596056d 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -373,8 +373,13 @@ def copy_(self, src, non_blocking=False):
   return self
 
 @torch.library.impl("aten::cat.out", "privateuseone")
-def cat_out(tensors, dim=0, out=None):
-  _apply_inplace(unwrap(out), Tensor.cat(*[unwrap(x) for x in tensors], dim=dim))
+def cat_out(tensors: list[torch.Tensor], dim: int=0, *, out: torch.Tensor):
+  # torch.cat accepts legacy-empty tensors of shape (0,) next to inputs of any rank. Give each of them out's shape
+  # with a zero-length cat dimension before the backend cat (presumably Tensor.cat needs same-rank inputs - confirm).
+  def fix_empty(t: torch.Tensor) -> torch.Tensor:
+    if t.shape != (0,): return t
+    return t.reshape([0 if i == (dim % out.ndim) else s for i, s in enumerate(out.shape)])
+  _apply_inplace(unwrap(out), Tensor.cat(*[unwrap(fix_empty(t)) for t in tensors], dim=dim))
   return out
 
 @torch.library.impl("aten::topk.values", "privateuseone")
diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py
index b56e126d8e..ef3043569a 100644
--- a/extra/torch_backend/test.py
+++ b/extra/torch_backend/test.py
@@ -808,6 +808,27 @@ class TestBackendHelpers(unittest.TestCase):
     np.testing.assert_equal(out.cpu().numpy(), [1, 2, 3, 4])
     assert ret is out
 
+  def test_cat_out_empty_1d(self):
+    # Run on both the backend device and torch cpu, so the expected values are also checked against torch itself
+    for test_device in device, "cpu":
+      a = torch.tensor([], device=test_device)
+      b = torch.tensor([1, 2, 3, 4], device=test_device).reshape((2, 2))
+      for dim in 0, 1, -1, -2:
+        out = torch.zeros((2, 2), device=test_device)  # fresh buffer per dim so a stale result cannot pass
+        ret = torch.cat([a, b], out=out, dim=dim)
+        np.testing.assert_equal(out.cpu().numpy(), [[1, 2], [3, 4]])
+        assert ret is out
+
+  def test_cat_all_empty(self):
+    # every input has shape (0,), so the concatenation is expected to stay empty
+    for test_device in device, "cpu":
+      a = torch.tensor([], device=test_device)
+      out = torch.empty((0,), device=test_device)
+      for dim in 0, -1:
+        ret = torch.cat([a, a], out=out, dim=dim)
+        np.testing.assert_equal(out.cpu().numpy(), [])
+        assert ret is out
+
   def test_scatter_add_out(self):
     src = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device, dtype=torch.float32)
     index = torch.tensor([[0, 1, 2], [0, 1, 2]], device=device)