mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Torch backend aten::cat.out fix (#16121)
* Handle empty 1D tensors in cat_out * Undid other changes * Fixed torch cat * Improved cat.out, added more tests * Cleaned code * Type hinted dim * Removed whitespace
This commit is contained in:
committed by
GitHub
parent
63c1f00b80
commit
effa263865
@@ -373,8 +373,12 @@ def copy_(self, src, non_blocking=False):
|
||||
return self
|
||||
|
||||
@torch.library.impl("aten::cat.out", "privateuseone")
|
||||
def cat_out(tensors, dim=0, out=None):
|
||||
_apply_inplace(unwrap(out), Tensor.cat(*[unwrap(x) for x in tensors], dim=dim))
|
||||
def cat_out(tensors: list[torch.Tensor], dim: int=0, *, out: torch.Tensor):
|
||||
fixed_tensors = []
|
||||
for wrapped in tensors:
|
||||
if wrapped.shape == (0,): wrapped = wrapped.reshape([0 if i == (dim % out.ndim) else x for i, x in enumerate(out.shape)])
|
||||
fixed_tensors.append(wrapped)
|
||||
_apply_inplace(unwrap(out), Tensor.cat(*map(unwrap, fixed_tensors), dim=dim))
|
||||
return out
|
||||
|
||||
@torch.library.impl("aten::topk.values", "privateuseone")
|
||||
|
||||
@@ -808,6 +808,26 @@ class TestBackendHelpers(unittest.TestCase):
|
||||
np.testing.assert_equal(out.cpu().numpy(), [1, 2, 3, 4])
|
||||
assert ret is out
|
||||
|
||||
def test_cat_out_empty_1d(self):
|
||||
# Test tiny and cpu to show test passes on torch cpu
|
||||
for test_device in device, "cpu":
|
||||
a = torch.tensor([], device=device)
|
||||
b = torch.tensor([1, 2, 3, 4], device=device).reshape((2, 2))
|
||||
out = torch.empty((2, 2), device=device)
|
||||
for dim in 0, 1, -1, -2:
|
||||
ret = torch.cat([a, b], out=out, dim=dim)
|
||||
np.testing.assert_equal(out.cpu().numpy(), [[1, 2], [3, 4]])
|
||||
assert ret is out
|
||||
|
||||
def test_cat_all_empty(self):
|
||||
for test_device in device, "cpu":
|
||||
a = torch.tensor([], device=device)
|
||||
out = torch.empty((0,), device=device)
|
||||
for dim in 0, -1:
|
||||
ret = torch.cat([a, a], out=out, dim=dim)
|
||||
np.testing.assert_equal(out.cpu().numpy(), [])
|
||||
assert ret is out
|
||||
|
||||
def test_scatter_add_out(self):
|
||||
src = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device, dtype=torch.float32)
|
||||
index = torch.tensor([[0, 1, 2], [0, 1, 2]], device=device)
|
||||
|
||||
Reference in New Issue
Block a user