Use is_dtype_supported in more places in tests (#7529)

This commit is contained in:
Ahmed Harmouche
2024-11-04 15:21:15 +01:00
committed by GitHub
parent 1d4df72798
commit 36488a2a43
5 changed files with 34 additions and 17 deletions

View File

@@ -1,11 +1,13 @@
import unittest
from tinygrad import Tensor
from tinygrad import Tensor, dtypes, Device
from tinygrad.engine.realize import lower_schedule
from test.helpers import is_dtype_supported
class TestCompileFailures(unittest.TestCase):
def compile(self, out:Tensor):
for _ in lower_schedule(out.schedule()): pass
@unittest.skipUnless(is_dtype_supported(dtypes.uchar, Device.DEFAULT), f"no uint8 on {Device.DEFAULT}")
def test_interpolate_atari(self):
self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64)))

View File

@@ -130,13 +130,16 @@ class TestMovedConstFolding(unittest.TestCase):
def test_cast_padded(self):
# NOTE: this is folded due to CAST_BEFORE_VIEW
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
if is_dtype_supported(dtypes.int16):
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
if is_dtype_supported(dtypes.uint16):
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# not folded
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
if is_dtype_supported(dtypes.int64):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
class TestReduceOpsConstFolding(unittest.TestCase):
def test_const_sum(self):

View File

@@ -454,6 +454,7 @@ class TestTypeSpec(unittest.TestCase):
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
def test_dtype_str_arg(self):
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
tested = 0
@@ -484,7 +485,8 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
@@ -493,20 +495,23 @@ class TestTypeSpec(unittest.TestCase):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
@@ -526,8 +531,10 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
if is_dtype_supported(dtypes.int16):
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7))
@@ -839,8 +846,9 @@ class TestTensorMethod(unittest.TestCase):
class TestDtypeUsage(unittest.TestCase):
def test_max_w_alu(self):
for d in dtypes.ints:
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()
if is_dtype_supported(d):
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()
if __name__ == '__main__':
unittest.main()

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
import torch
from tinygrad import Tensor, Device, TinyJit
from tinygrad import Tensor, Device, TinyJit, dtypes
from tinygrad.ops import Ops
from tinygrad.helpers import CI, Context
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
@@ -10,6 +10,7 @@ from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNo
from tinygrad.nn.state import load_state_dict
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported
@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
class TestNN(unittest.TestCase):
@@ -474,6 +475,7 @@ class TestNN(unittest.TestCase):
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_embedding(self):
B, T, embed_size, vocab_size = 4, 10, 20, 28

View File

@@ -5,6 +5,7 @@ import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from test.helpers import is_dtype_supported
if CI:
import warnings
@@ -2180,6 +2181,7 @@ class TestOps(unittest.TestCase):
def test_bitcast(self):
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
class TestOpsUint8(unittest.TestCase):
@unittest.skip('this is broken for negative numbers')
def test_cast(self):