Torch backend aten::cat.out fix (#16121)

* Handle empty 1D tensors in cat_out

* Undid other changes

* Fixed torch cat

* Improved cat.out, added more tests

* Cleaned code

* Type hinted dim

* Removed whitespace
This commit is contained in:
Vikram Rangarajan
2026-05-11 19:28:16 -04:00
committed by GitHub
parent 63c1f00b80
commit effa263865
2 changed files with 26 additions and 2 deletions

View File

@@ -373,8 +373,12 @@ def copy_(self, src, non_blocking=False):
return self
@torch.library.impl("aten::cat.out", "privateuseone")
def cat_out(tensors, dim=0, out=None):
_apply_inplace(unwrap(out), Tensor.cat(*[unwrap(x) for x in tensors], dim=dim))
def cat_out(tensors: list[torch.Tensor], dim: int=0, *, out: torch.Tensor):
  """Concatenate ``tensors`` along ``dim`` and write the result into ``out``.

  PyTorch allows a 1-D tensor of shape ``(0,)`` to take part in a
  concatenation with tensors of any rank; it contributes no elements. Each
  such input is reshaped here into a zero-length placeholder of matching rank
  before the list is passed to ``Tensor.cat``.

  Args:
    tensors: Input tensors; entries of shape ``(0,)`` are treated as legacy
      empties.
    dim: Concatenation axis; negative values count from the last axis.
    out: Destination tensor; its unwrapped form is handed to
      ``_apply_inplace`` together with the concatenated result.

  Returns:
    ``out``, the same object that was passed in.
  """
  def is_legacy_empty(t: torch.Tensor) -> bool: return tuple(t.shape) == (0,)

  # Take the placeholder shape from a real input rather than from `out`:
  # `out` is caller-supplied and is not guaranteed to already have the result's
  # rank (a 0-dim `out` would also make `dim % out.ndim` divide by zero).
  reference = next((t for t in tensors if not is_legacy_empty(t)), None)
  if reference is None or reference.ndim == 0:
    # Either every input is a legacy empty (already mutually compatible) or
    # there is no axis to align against; pass the inputs through untouched.
    fixed_tensors = list(tensors)
  else:
    axis = dim % reference.ndim  # normalise negative dims
    placeholder_shape = [0 if i == axis else size for i, size in enumerate(reference.shape)]
    fixed_tensors = [t.reshape(placeholder_shape) if is_legacy_empty(t) else t for t in tensors]
  _apply_inplace(unwrap(out), Tensor.cat(*map(unwrap, fixed_tensors), dim=dim))
  return out
@torch.library.impl("aten::topk.values", "privateuseone")

View File

@@ -808,6 +808,26 @@ class TestBackendHelpers(unittest.TestCase):
np.testing.assert_equal(out.cpu().numpy(), [1, 2, 3, 4])
assert ret is out
def test_cat_out_empty_1d(self):
  """A 1-D empty input leaves ``torch.cat(..., out=...)`` equal to the other input."""
  # Run once on the module-level `device` and once on stock "cpu" so the same
  # expectation is checked against reference torch behaviour.
  for test_device in device, "cpu":
    a = torch.tensor([], device=test_device)
    b = torch.tensor([1, 2, 3, 4], device=test_device).reshape((2, 2))
    out = torch.empty((2, 2), device=test_device)
    for dim in 0, 1, -1, -2:
      # subTest keeps going after a failure and reports which device/dim broke.
      with self.subTest(device=test_device, dim=dim):
        ret = torch.cat([a, b], out=out, dim=dim)
        np.testing.assert_equal(out.cpu().numpy(), [[1, 2], [3, 4]])
        assert ret is out
def test_cat_all_empty(self):
  """Concatenating only 1-D empty tensors into an empty ``out`` yields an empty result."""
  # Run once on the module-level `device` and once on stock "cpu" so the same
  # expectation is checked against reference torch behaviour.
  for test_device in device, "cpu":
    a = torch.tensor([], device=test_device)
    out = torch.empty((0,), device=test_device)
    for dim in 0, -1:
      # subTest keeps going after a failure and reports which device/dim broke.
      with self.subTest(device=test_device, dim=dim):
        ret = torch.cat([a, a], out=out, dim=dim)
        np.testing.assert_equal(out.cpu().numpy(), [])
        assert ret is out
def test_scatter_add_out(self):
src = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device, dtype=torch.float32)
index = torch.tensor([[0, 1, 2], [0, 1, 2]], device=device)