From a4abcf0969689653047dcfd0c0f907f2dd210702 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 12 Mar 2023 22:59:40 -0700 Subject: [PATCH] improve test_example --- test/unit/test_example.py | 51 +++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/test/unit/test_example.py b/test/unit/test_example.py index 63bffa502d..c64c9f6ae2 100644 --- a/test/unit/test_example.py +++ b/test/unit/test_example.py @@ -1,40 +1,59 @@ import unittest +from tinygrad.lazy import Device from tinygrad.tensor import Tensor +def multidevice_test(fxn): + def ret(self): + for device in Device._buffers: + with self.subTest(device=device): + try: + Device[device] + except Exception: + print(f"WARNING: {device} test isn't running") + continue + fxn(self, device) + return ret + class TestExample(unittest.TestCase): - def _test_example_readme(self, device): + @multidevice_test + def test_2_plus_3(self, device): + a = Tensor([2], device=device) + b = Tensor([3], device=device) + result = a + b + print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}") + assert result.numpy()[0] == 5. + + @multidevice_test + def test_example_readme(self, device): x = Tensor.eye(3, device=device, requires_grad=True) y = Tensor([[2.0,0,-2.0]], device=device, requires_grad=True) z = y.matmul(x).sum() z.backward() - print(x.grad.numpy()) # dz/dx - print(y.grad.numpy()) # dz/dy + x.grad.numpy() # dz/dx + y.grad.numpy() # dz/dy assert x.grad.device == device assert y.grad.device == device - def _test_example_matmul(self, device): + @multidevice_test + def test_example_matmul(self, device): + try: + Device[device] + except Exception: + print(f"WARNING: {device} test isn't running") + return + x = Tensor.eye(64, device=device, requires_grad=True) y = Tensor.eye(64, device=device, requires_grad=True) z = y.matmul(x).sum() z.backward() - print(x.grad.numpy()) # dz/dx - print(y.grad.numpy()) # dz/dy + x.grad.numpy() # dz/dx + y.grad.numpy() # dz/dy assert x.grad.device == device assert y.grad.device == device - def test_example_readme_cpu(self): self._test_example_readme("CPU") - def test_example_readme_gpu(self): self._test_example_readme("GPU") - def test_example_readme_torch(self): self._test_example_readme("TORCH") - def test_example_readme_llvm(self): self._test_example_readme("LLVM") - - def test_example_matmul_cpu(self): self._test_example_matmul("CPU") - def test_example_matmul_gpu(self): self._test_example_matmul("GPU") - def test_example_matmul_torch(self): self._test_example_matmul("TORCH") - def test_example_matmul_llvm(self): self._test_example_matmul("LLVM") - if __name__ == '__main__': unittest.main()