mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
@@ -46,9 +46,9 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N
|
||||
np.testing.assert_allclose(tst.numpy(), ref.numpy(), atol=atol, rtol=rtol)
|
||||
np.testing.assert_allclose(a.grad.numpy(), a_ref.grad.numpy(), atol=grad_atol, rtol=grad_rtol)
|
||||
np.testing.assert_allclose(b.grad.numpy(), b_ref.grad.numpy(), atol=grad_atol, rtol=grad_rtol)
|
||||
assert tst.allclose(ref, atol=atol, rtol=rtol), "forward mismatch"
|
||||
assert a.grad.allclose(a_ref.grad, atol=grad_atol, rtol=grad_rtol), "grad_a mismatch"
|
||||
assert b.grad.allclose(b_ref.grad, atol=grad_atol, rtol=grad_rtol), "grad_b mismatch"
|
||||
assert tst.allclose(ref, atol=atol, rtol=rtol).item(), "forward mismatch"
|
||||
assert a.grad.allclose(a_ref.grad, atol=grad_atol, rtol=grad_rtol).item(), "grad_a mismatch"
|
||||
assert b.grad.allclose(b_ref.grad, atol=grad_atol, rtol=grad_rtol).item(), "grad_b mismatch"
|
||||
|
||||
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None:
|
||||
run_asm_gemm((batch, M, K), (K, N), dtype=dtype, a_shard=0, b_shard=None, gpus=gpus)
|
||||
|
||||
@@ -189,7 +189,7 @@ class TestCustomKernel(unittest.TestCase):
|
||||
A = Tensor.randn(16, 16).contiguous()
|
||||
B = Tensor.empty(16)
|
||||
B = Tensor.custom_kernel(B, A, fxn=slice_sum_kernel)[0]
|
||||
self.assertTrue(B.allclose(A.sum(1)))
|
||||
self.assertTrue(B.allclose(A.sum(1)).item())
|
||||
|
||||
def test_gemm(self):
|
||||
N = 16
|
||||
|
||||
@@ -666,7 +666,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
rng = Tensor.rand((10, 10, 10))
|
||||
t0 = rng.shard(devices_2, axis=1)
|
||||
out = t0.flip(0) + 1
|
||||
self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)))
|
||||
self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)).item())
|
||||
|
||||
@unittest.skip("flaky")
|
||||
def test_reshape_on_axis(self):
|
||||
|
||||
@@ -1203,11 +1203,11 @@ class Tensor(OpMixin):
|
||||
|
||||
# ***** reduce ops *****
|
||||
|
||||
def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> bool:
|
||||
def allclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor:
|
||||
"""
|
||||
Check if all self and other are close. Return True or False.
|
||||
Check if all self and other are close.
|
||||
"""
|
||||
return bool(self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all().item())
|
||||
return self.isclose(other, rtol=rtol, atol=atol, equal_nan=equal_nan).all()
|
||||
|
||||
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user