add operator.lt and operator.eq to test_dtype_alu (#3191)

* add operator.lt and operator.eq to test_dtype_alu

those should pass now as we have broadcasted before passing to lt and eq.
also updated the test skipping criteria to reuse test_dtype.is_dtype_supported

* llvm lt nan is incorrect

* enable truediv too

* Revert "enable truediv too"

This reverts commit df703235fb.

* just that
This commit is contained in:
chenyu
2024-01-20 14:54:02 -05:00
committed by GitHub
parent c4b5661146
commit 3f56d1a5e8

View File

@@ -5,8 +5,9 @@ import operator
import numpy as np
from hypothesis import given, strategies as strat, settings
from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv, OSX
from tinygrad.helpers import CI, getenv
from tinygrad.ops import UnaryOps, get_lazyop_info
from test.test_dtype import is_dtype_supported
settings.register_profile("my_profile", max_examples=200, deadline=None)
settings.load_profile("my_profile")
@@ -15,7 +16,7 @@ print(settings.default)
dtypes_float = (dtypes.float32, dtypes.float16)
dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
dtypes_bool = (dtypes.bool,)
binary_operations = [operator.add, operator.sub, operator.mul]
binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq]
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor)]
unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin),
(Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]
@@ -26,12 +27,14 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (T
# TODO: enable mod on Tensor
#binary_operations.append(operator.mod)
# TODO: lt and eq should cast in tensor before we can test them, this is a separate project
#binary_operations += [operator.lt, operator.eq]
# TODO: (a+b)/2 in tensor.py's maximum can overflow. This requires a new implementation of maximum that can be backpropagated
#binary_operations += [(Tensor.maximum, np.maximum)]
# TODO: LLVM comparing with nan is incorrect
if Device.DEFAULT == "LLVM":
binary_operations.remove(operator.lt)
binary_operations.remove(operator.eq)
# TODO: CUDACPU segfaults on sin
if getenv("CUDACPU"): unary_operations.remove((Tensor.sin, np.sin))
@@ -87,39 +90,36 @@ def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if getenv("PTX") else 1e-7)
class TestDTypeALU(unittest.TestCase):
@unittest.skipIf(OSX and Device.DEFAULT in {"GPU", "METAL"}, "no float64 on OSX GPU")
@unittest.skipUnless(is_dtype_supported(dtypes.float64, Device.DEFAULT), f"no float64 on {Device.DEFAULT}")
@given(ht.float64, ht.float64, strat.sampled_from(binary_operations))
def test_float64(self, a, b, op): universal_test(a, b, dtypes.float64, op)
@given(ht.float32, ht.float32, strat.sampled_from(binary_operations))
def test_float32(self, a, b, op): universal_test(a, b, dtypes.float32, op)
# GPU requires cl_khr_fp16
# for LLVM, it segfaults because it can't link to the casting function
# CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
@unittest.skipIf((Device.DEFAULT in ["GPU", "LLVM"] and CI) or getenv("CUDACPU"), "")
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
@given(ht.float16, ht.float16, strat.sampled_from(binary_operations))
def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)
@given(ht.float32, strat.sampled_from(unary_operations))
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
@unittest.skipIf((Device.DEFAULT in ["GPU", "LLVM"] and CI) or getenv("CUDACPU"), "")
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
@given(ht.float16, strat.sampled_from(unary_operations))
def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint16 in torch")
@unittest.skipUnless(is_dtype_supported(dtypes.uint16, Device.DEFAULT), f"no uint16 on {Device.DEFAULT}")
@given(ht.uint16, ht.uint16, strat.sampled_from(integer_binary_operations))
def test_uint16(self, a, b, op): universal_test(a, b, dtypes.uint16, op)
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
@unittest.skipUnless(is_dtype_supported(dtypes.uint32, Device.DEFAULT), f"no uint32 on {Device.DEFAULT}")
@given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
@unittest.skipUnless(is_dtype_supported(dtypes.uint64, Device.DEFAULT), f"no uint64 on {Device.DEFAULT}")
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)