diff --git a/test/unit/test_transcendental_helpers.py b/test/unit/test_transcendental_helpers.py index f51ad1c973..126d6bbabf 100644 --- a/test/unit/test_transcendental_helpers.py +++ b/test/unit/test_transcendental_helpers.py @@ -1,8 +1,9 @@ import unittest, math import numpy as np from tinygrad import dtypes +from tinygrad.dtype import DType from tinygrad.ops import UOp, Ops -from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if +from tinygrad.codegen.transcendental import TRANSCENDENTAL_SUPPORTED_DTYPES, payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if, xpow from test.helpers import eval_uop class TestTranscendentalFunctions(unittest.TestCase): @@ -70,5 +71,26 @@ class TestTranscendentalFunctions(unittest.TestCase): np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -10), dtypes.float)), 2**-10) np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -63), dtypes.float)), 2**-63) +class TestVectorizedTranscendetalFunctions(unittest.TestCase): + def _check_all_uops_vectorized(self, u:tuple|UOp, vcount:int): + # check all UOps in u are vectorized with vcount + if isinstance(u, UOp): assert u.dtype.vcount == vcount, f'expected {vcount=} but got {u.dtype.vcount=} for UOp {u=}' + [self._check_all_uops_vectorized(x, vcount) for x in (u if isinstance(u, tuple) else u.src)] + + def _get_inputs(self) -> tuple[UOp, DType]: + for val in [-2,1.3,194]: + for vcount in [1,2,4,19]: + for _dtype in TRANSCENDENTAL_SUPPORTED_DTYPES: + dtype: DType = _dtype.vec(vcount) + d = UOp.const(dtype, val) + yield d, dtype + + def test_preserves_vectorization(self): + # verify that when given a vectorized (or scalar) input, the function returns a vectorized (or scalar) output + for d, dtype in self._get_inputs(): + self._check_all_uops_vectorized(payne_hanek_reduction(d), dtype.vcount) + self._check_all_uops_vectorized(cody_waite_reduction(d), dtype.vcount) + self._check_all_uops_vectorized(xpow(d, d), dtype.vcount) + if __name__ == '__main__': unittest.main()