update Tensor.allclose to return Tensor (#15904)

matches jax
This commit is contained in:
chenyu
2026-04-24 08:27:17 -04:00
committed by GitHub
parent 48d7ab2695
commit 7a1adfd2aa
4 changed files with 8 additions and 8 deletions

View File

@@ -46,9 +46,9 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N
np.testing.assert_allclose(tst.numpy(), ref.numpy(), atol=atol, rtol=rtol)
np.testing.assert_allclose(a.grad.numpy(), a_ref.grad.numpy(), atol=grad_atol, rtol=grad_rtol)
np.testing.assert_allclose(b.grad.numpy(), b_ref.grad.numpy(), atol=grad_atol, rtol=grad_rtol)
assert tst.allclose(ref, atol=atol, rtol=rtol), "forward mismatch"
assert a.grad.allclose(a_ref.grad, atol=grad_atol, rtol=grad_rtol), "grad_a mismatch"
assert b.grad.allclose(b_ref.grad, atol=grad_atol, rtol=grad_rtol), "grad_b mismatch"
assert tst.allclose(ref, atol=atol, rtol=rtol).item(), "forward mismatch"
assert a.grad.allclose(a_ref.grad, atol=grad_atol, rtol=grad_rtol).item(), "grad_a mismatch"
assert b.grad.allclose(b_ref.grad, atol=grad_atol, rtol=grad_rtol).item(), "grad_b mismatch"
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None:
run_asm_gemm((batch, M, K), (K, N), dtype=dtype, a_shard=0, b_shard=None, gpus=gpus)

View File

@@ -189,7 +189,7 @@ class TestCustomKernel(unittest.TestCase):
A = Tensor.randn(16, 16).contiguous()
B = Tensor.empty(16)
B = Tensor.custom_kernel(B, A, fxn=slice_sum_kernel)[0]
self.assertTrue(B.allclose(A.sum(1)))
self.assertTrue(B.allclose(A.sum(1)).item())
def test_gemm(self):
N = 16

View File

@@ -666,7 +666,7 @@ class TestMultiTensor(unittest.TestCase):
rng = Tensor.rand((10, 10, 10))
t0 = rng.shard(devices_2, axis=1)
out = t0.flip(0) + 1
self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)))
self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)).item())
@unittest.skip("flaky")
def test_reshape_on_axis(self):

View File

@@ -1203,11 +1203,11 @@ class Tensor(OpMixin):
# ***** reduce ops *****
def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> bool:
def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor:
"""
Check if all self and other are close. Return True or False.
Check if all self and other are close.
"""
return bool(self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all().item())
return self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all()
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
"""