mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Use is_dtype_supported in more places in tests (#7529)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
class TestCompileFailures(unittest.TestCase):
|
||||
def compile(self, out:Tensor):
|
||||
for _ in lower_schedule(out.schedule()): pass
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar, Device.DEFAULT), f"no uint8 on {Device.DEFAULT}")
|
||||
def test_interpolate_atari(self):
|
||||
self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64)))
|
||||
|
||||
|
||||
@@ -130,13 +130,16 @@ class TestMovedConstFolding(unittest.TestCase):
|
||||
|
||||
def test_cast_padded(self):
|
||||
# NOTE: this is folded due to CAST_BEFORE_VIEW
|
||||
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
if is_dtype_supported(dtypes.uint16):
|
||||
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
# not folded
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
|
||||
class TestReduceOpsConstFolding(unittest.TestCase):
|
||||
def test_const_sum(self):
|
||||
|
||||
@@ -454,6 +454,7 @@ class TestTypeSpec(unittest.TestCase):
|
||||
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
|
||||
def test_dtype_str_arg(self):
|
||||
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
|
||||
tested = 0
|
||||
@@ -484,7 +485,8 @@ class TestTypeSpec(unittest.TestCase):
|
||||
|
||||
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
|
||||
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
|
||||
|
||||
@@ -493,20 +495,23 @@ class TestTypeSpec(unittest.TestCase):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
|
||||
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
@@ -526,8 +531,10 @@ class TestTypeSpec(unittest.TestCase):
|
||||
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
|
||||
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
|
||||
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7))
|
||||
@@ -839,8 +846,9 @@ class TestTensorMethod(unittest.TestCase):
|
||||
class TestDtypeUsage(unittest.TestCase):
|
||||
def test_max_w_alu(self):
|
||||
for d in dtypes.ints:
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
if is_dtype_supported(d):
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from tinygrad import Tensor, Device, TinyJit
|
||||
from tinygrad import Tensor, Device, TinyJit, dtypes
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.helpers import CI, Context
|
||||
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
|
||||
@@ -10,6 +10,7 @@ from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNo
|
||||
from tinygrad.nn.state import load_state_dict
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
|
||||
class TestNN(unittest.TestCase):
|
||||
@@ -474,6 +475,7 @@ class TestNN(unittest.TestCase):
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
|
||||
def test_embedding(self):
|
||||
B, T, embed_size, vocab_size = 4, 10, 20, 28
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
if CI:
|
||||
import warnings
|
||||
@@ -2180,6 +2181,7 @@ class TestOps(unittest.TestCase):
|
||||
def test_bitcast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
|
||||
class TestOpsUint8(unittest.TestCase):
|
||||
@unittest.skip('this is broken for negative numbers')
|
||||
def test_cast(self):
|
||||
|
||||
Reference in New Issue
Block a user