mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 00:45:16 +08:00
add operator.lt and operator.eq to test_dtype_alu (#3191)
* add operator.lt and operator.eq to test_dtype_alu
those should pass now as we have broadcasted before passing to lt and eq.
also updated the test skipping criteria to reuse test_dtype.is_dtype_supported
* llvm lt nan is incorrect
* enable truediv too
* Revert "enable truediv too"
This reverts commit df703235fb.
* just that
This commit is contained in:
@@ -5,8 +5,9 @@ import operator
|
||||
import numpy as np
|
||||
from hypothesis import given, strategies as strat, settings
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.helpers import CI, getenv, OSX
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.ops import UnaryOps, get_lazyop_info
|
||||
from test.test_dtype import is_dtype_supported
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None)
|
||||
settings.load_profile("my_profile")
|
||||
@@ -15,7 +16,7 @@ print(settings.default)
|
||||
dtypes_float = (dtypes.float32, dtypes.float16)
|
||||
dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
dtypes_bool = (dtypes.bool,)
|
||||
binary_operations = [operator.add, operator.sub, operator.mul]
|
||||
binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq]
|
||||
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor)]
|
||||
unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin),
|
||||
(Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]
|
||||
@@ -26,12 +27,14 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (T
|
||||
# TODO: enable mod on Tensor
|
||||
#binary_operations.append(operator.mod)
|
||||
|
||||
# TODO: lt and eq should cast in tensor before we can test them, this is a separate project
|
||||
#binary_operations += [operator.lt, operator.eq]
|
||||
|
||||
# TODO: (a+b)/2 in tensor.py's maximum can overflow. This requires a new implementation of maximum that can be backpropagated
|
||||
#binary_operations += [(Tensor.maximum, np.maximum)]
|
||||
|
||||
# TODO: LLVM comparing with nan is incorrect
|
||||
if Device.DEFAULT == "LLVM":
|
||||
binary_operations.remove(operator.lt)
|
||||
binary_operations.remove(operator.eq)
|
||||
|
||||
# TODO: CUDACPU segfaults on sin
|
||||
if getenv("CUDACPU"): unary_operations.remove((Tensor.sin, np.sin))
|
||||
|
||||
@@ -87,39 +90,36 @@ def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
|
||||
np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if getenv("PTX") else 1e-7)
|
||||
|
||||
class TestDTypeALU(unittest.TestCase):
|
||||
@unittest.skipIf(OSX and Device.DEFAULT in {"GPU", "METAL"}, "no float64 on OSX GPU")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float64, Device.DEFAULT), f"no float64 on {Device.DEFAULT}")
|
||||
@given(ht.float64, ht.float64, strat.sampled_from(binary_operations))
|
||||
def test_float64(self, a, b, op): universal_test(a, b, dtypes.float64, op)
|
||||
|
||||
@given(ht.float32, ht.float32, strat.sampled_from(binary_operations))
|
||||
def test_float32(self, a, b, op): universal_test(a, b, dtypes.float32, op)
|
||||
|
||||
# GPU requires cl_khr_fp16
|
||||
# for LLVM, it segfaults because it can't link to the casting function
|
||||
# CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
|
||||
@unittest.skipIf((Device.DEFAULT in ["GPU", "LLVM"] and CI) or getenv("CUDACPU"), "")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
|
||||
@given(ht.float16, ht.float16, strat.sampled_from(binary_operations))
|
||||
def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)
|
||||
|
||||
@given(ht.float32, strat.sampled_from(unary_operations))
|
||||
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
|
||||
|
||||
@unittest.skipIf((Device.DEFAULT in ["GPU", "LLVM"] and CI) or getenv("CUDACPU"), "")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
|
||||
@given(ht.float16, strat.sampled_from(unary_operations))
|
||||
def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
|
||||
|
||||
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint16 in torch")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint16, Device.DEFAULT), f"no uint16 on {Device.DEFAULT}")
|
||||
@given(ht.uint16, ht.uint16, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint16(self, a, b, op): universal_test(a, b, dtypes.uint16, op)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint32, Device.DEFAULT), f"no uint32 on {Device.DEFAULT}")
|
||||
@given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint64, Device.DEFAULT), f"no uint64 on {Device.DEFAULT}")
|
||||
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user