From 7a1bfb668d5fa4301e114ff11788d5d5804d3269 Mon Sep 17 00:00:00 2001 From: Xingyu Date: Wed, 4 Jun 2025 19:59:50 +0800 Subject: [PATCH] Implement linalg_eigh function for tensor eigenvalue decomposition in torch backend (#10612) * Implement private _linalg_eigh function for tensor eigenvalue decomposition in torch backend * Add unit test for linalg.eigh function in TestTorchBackend This test verifies the eigenvalue decomposition of a 2x2 tensor using the linalg.eigh function, ensuring the computed eigenvalues and reconstructed tensor match the expected results. --- extra/torch_backend/backend.py | 6 ++++++ extra/torch_backend/test.py | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index a86b3d5477..c7f082cd2c 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -121,6 +121,12 @@ def cummax(self, dim): # TODO: move to tinygrad def nonzero(self): return aten.nonzero(self.cpu()).tiny() +@torch.library.impl("aten::_linalg_eigh", "privateuseone") +# TODO: move to tinygrad +def _linalg_eigh(self, UPLO: str = 'U'): + w, v = torch.linalg.eigh(self.cpu(), UPLO=UPLO) + return w.tiny(), v.tiny() + def upsample_backward(grad_out, output_size, input_size, *args, f=None): return f(grad_out.cpu(), output_size, input_size, *args).tiny() for i in [ diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py index 113c013c81..5d84226b00 100644 --- a/extra/torch_backend/test.py +++ b/extra/torch_backend/test.py @@ -170,6 +170,13 @@ class TestTorchBackend(unittest.TestCase): assert torch.equal(tensor_a, tensor_b) assert not torch.equal(tensor_a, tensor_c) + def test_linalg_eigh(self): + a = torch.tensor([[1, 2], [2, 1]], dtype=torch.float32, device=device) + w, v = torch.linalg.eigh(a) + np.testing.assert_equal(w.cpu().numpy(), [-1, 3]) + recon = (v @ torch.diag(w) @ v.T).cpu().numpy() + np.testing.assert_allclose(recon, a.cpu().numpy(), atol=1e-6) + def test_scalar_assign(self): a = torch.tensor([1, 2, 3], device=device) a[1] = 4