From 7a1adfd2aa8ce11de4d22d8bfbedda109976a24e Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 24 Apr 2026 08:27:17 -0400 Subject: [PATCH] update Tensor.allclose to return Tensor (#15904) matches jax --- test/backend/test_asm_gemm.py | 6 +++--- test/backend/test_custom_kernel.py | 2 +- test/backend/test_multitensor.py | 2 +- tinygrad/tensor.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/backend/test_asm_gemm.py b/test/backend/test_asm_gemm.py index 7d006a4122..518aaeeb3e 100644 --- a/test/backend/test_asm_gemm.py +++ b/test/backend/test_asm_gemm.py @@ -46,9 +46,9 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N np.testing.assert_allclose(tst.numpy(), ref.numpy(), atol=atol, rtol=rtol) np.testing.assert_allclose(a.grad.numpy(), a_ref.grad.numpy(), atol=grad_atol, rtol=grad_rtol) np.testing.assert_allclose(b.grad.numpy(), b_ref.grad.numpy(), atol=grad_atol, rtol=grad_rtol) - assert tst.allclose(ref, atol=atol, rtol=rtol), "forward mismatch" - assert a.grad.allclose(a_ref.grad, atol=grad_atol, rtol=grad_rtol), "grad_a mismatch" - assert b.grad.allclose(b_ref.grad, atol=grad_atol, rtol=grad_rtol), "grad_b mismatch" + assert tst.allclose(ref, atol=atol, rtol=rtol).item(), "forward mismatch" + assert a.grad.allclose(a_ref.grad, atol=grad_atol, rtol=grad_rtol).item(), "grad_a mismatch" + assert b.grad.allclose(b_ref.grad, atol=grad_atol, rtol=grad_rtol).item(), "grad_b mismatch" def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None: run_asm_gemm((batch, M, K), (K, N), dtype=dtype, a_shard=0, b_shard=None, gpus=gpus) diff --git a/test/backend/test_custom_kernel.py b/test/backend/test_custom_kernel.py index ba4d834200..f2971086bf 100644 --- a/test/backend/test_custom_kernel.py +++ b/test/backend/test_custom_kernel.py @@ -189,7 +189,7 @@ class TestCustomKernel(unittest.TestCase): A = Tensor.randn(16, 16).contiguous() B = Tensor.empty(16) B = Tensor.custom_kernel(B, A, fxn=slice_sum_kernel)[0] - self.assertTrue(B.allclose(A.sum(1))) + self.assertTrue(B.allclose(A.sum(1)).item()) def test_gemm(self): N = 16 diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py index ceed23e0b1..2ea7de634d 100644 --- a/test/backend/test_multitensor.py +++ b/test/backend/test_multitensor.py @@ -666,7 +666,7 @@ class TestMultiTensor(unittest.TestCase): rng = Tensor.rand((10, 10, 10)) t0 = rng.shard(devices_2, axis=1) out = t0.flip(0) + 1 - self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device))) + self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)).item()) @unittest.skip("flaky") def test_reshape_on_axis(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 699aa1423e..bd67b6d19a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1203,11 +1203,11 @@ class Tensor(OpMixin): # ***** reduce ops ***** - def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> bool: + def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor: """ - Check if all self and other are close. Return True or False. + Check if all self and other are close. """ - return bool(self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all().item()) + return self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all() def keccak(self, cfg:str|tuple[int, int]="sha3_256"): """