From 36488a2a4398a4e1f7d16ebe6fcd9df239de4bbe Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Mon, 4 Nov 2024 15:21:15 +0100 Subject: [PATCH] Use is_dtype_supported in more places in tests (#7529) --- test/test_compile_failures.py | 4 +++- test/test_const_folding.py | 15 +++++++++------ test/test_dtype.py | 26 +++++++++++++++++--------- test/test_nn.py | 4 +++- test/test_ops.py | 2 ++ 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/test/test_compile_failures.py b/test/test_compile_failures.py index b2b03d42e7..baa20c6316 100644 --- a/test/test_compile_failures.py +++ b/test/test_compile_failures.py @@ -1,11 +1,13 @@ import unittest -from tinygrad import Tensor +from tinygrad import Tensor, dtypes, Device from tinygrad.engine.realize import lower_schedule +from test.helpers import is_dtype_supported class TestCompileFailures(unittest.TestCase): def compile(self, out:Tensor): for _ in lower_schedule(out.schedule()): pass + @unittest.skipUnless(is_dtype_supported(dtypes.uchar, Device.DEFAULT), f"no uint8 on {Device.DEFAULT}") def test_interpolate_atari(self): self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64))) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 9fa559e20d..11f9877136 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -130,13 +130,16 @@ class TestMovedConstFolding(unittest.TestCase): def test_cast_padded(self): # NOTE: this is folded due to CAST_BEFORE_VIEW - _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16)) - np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0]) - _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16)) - np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0]) + if is_dtype_supported(dtypes.int16): + _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16)) + np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0]) + if is_dtype_supported(dtypes.uint16): + _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16)) + np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0]) # not folded - _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64)) - np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0]) + if is_dtype_supported(dtypes.int64): + _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64)) + np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0]) class TestReduceOpsConstFolding(unittest.TestCase): def test_const_sum(self): diff --git a/test/test_dtype.py b/test/test_dtype.py index faaff6b3a6..6cf53dca62 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -454,6 +454,7 @@ class TestTypeSpec(unittest.TestCase): subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'], shell=True, check=True) + @unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}") def test_dtype_str_arg(self): n = np.random.normal(0, 1, (10, 10)).astype(np.float32) tested = 0 @@ -484,7 +485,8 @@ class TestTypeSpec(unittest.TestCase): _assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0)) _assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3)) - _assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3)) + if is_dtype_supported(dtypes.int64): + _assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3)) if is_dtype_supported(dtypes.float16): _assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3)) @@ -493,20 +495,23 @@ class TestTypeSpec(unittest.TestCase): dtypes.default_int, dtypes.default_float = default_int, default_float _assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3))) - _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3))) + if is_dtype_supported(dtypes.int64): + _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3))) if is_dtype_supported(dtypes.float16): _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3))) _assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3))) - _assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3))) + if is_dtype_supported(dtypes.int64): + _assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3))) if is_dtype_supported(dtypes.float16): _assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3))) _assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0)) _assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3)) _assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True)) - _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3)) - _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3)) + if is_dtype_supported(dtypes.int64): + _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3)) + _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3)) if is_dtype_supported(dtypes.float16): _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3)) _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3)) @@ -526,8 +531,10 @@ class TestTypeSpec(unittest.TestCase): _assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5)) _assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120)) _assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5)) - _assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5)) - _assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5)) + if is_dtype_supported(dtypes.int16): + _assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5)) + if is_dtype_supported(dtypes.int64): + _assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5)) if is_dtype_supported(dtypes.float16): _assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5)) _assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7)) @@ -839,8 +846,9 @@ class TestTensorMethod(unittest.TestCase): class TestDtypeUsage(unittest.TestCase): def test_max_w_alu(self): for d in dtypes.ints: - t = Tensor([[1, 2], [3, 4]], dtype=d) - (t*t).max().item() + if is_dtype_supported(d): + t = Tensor([[1, 2], [3, 4]], dtype=d) + (t*t).max().item() if __name__ == '__main__': unittest.main() diff --git a/test/test_nn.py b/test/test_nn.py index e5c241c544..c36db2f706 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2,7 +2,7 @@ import unittest import numpy as np import torch -from tinygrad import Tensor, Device, TinyJit +from tinygrad import Tensor, Device, TinyJit, dtypes from tinygrad.ops import Ops from tinygrad.helpers import CI, Context from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding @@ -10,6 +10,7 @@ from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNo from tinygrad.nn.state import load_state_dict from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import run_schedule +from test.helpers import is_dtype_supported @unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow") class TestNN(unittest.TestCase): @@ -474,6 +475,7 @@ class TestNN(unittest.TestCase): np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3) + @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}") def test_embedding(self): B, T, embed_size, vocab_size = 4, 10, 20, 28 diff --git a/test/test_ops.py b/test/test_ops.py index 3c7cf939c2..402eaa4684 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -5,6 +5,7 @@ import torch from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL from tinygrad import Tensor, Device, dtypes from tinygrad.tensor import _to_np_dtype +from test.helpers import is_dtype_supported if CI: import warnings @@ -2180,6 +2181,7 @@ class TestOps(unittest.TestCase): def test_bitcast(self): helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True) +@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}") class TestOpsUint8(unittest.TestCase): @unittest.skip('this is broken for negative numbers') def test_cast(self):