diff --git a/setup.py b/setup.py index 356450e0a4..e132b94169 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ testing_minimal = [ "pytest-xdist", "hypothesis==6.131.0", "z3-solver", + "ml_dtypes" ] setup(name='tinygrad', diff --git a/test/test_dtype.py b/test/test_dtype.py index 7773eb4848..b41a9663a3 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -5,10 +5,12 @@ from typing import Any, List from tinygrad.device import is_dtype_supported from tinygrad.helpers import getenv, DEBUG, CI from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, truncate_bf16, to_dtype +from tinygrad.dtype import truncate, fp8_to_float, float_to_fp8 from tinygrad import Device, Tensor, dtypes from tinygrad.tensor import _to_np_dtype from hypothesis import assume, given, settings, strategies as strat from test.helpers import rand_for_dtype +import ml_dtypes import pytest pytestmark = pytest.mark.filterwarnings("ignore") @@ -19,6 +21,8 @@ core_dtypes = list(DTYPES_DICT.values()) if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)] dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)] +FP8E4M3_MAX = 448.0 +FP8E5M2_MAX = 57344.0 def get_available_cast_dtypes(dtype: DType) -> List[DType]: if not is_dtype_supported(dtype): return [] @@ -146,6 +150,39 @@ class TestFp8s(unittest.TestCase): def test_fp8e4m3_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3).dtype == dtypes.fp8e4m3 def test_fp8e5m2_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2).dtype == dtypes.fp8e5m2 +class TestFp8sConversions(unittest.TestCase): + @given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX)) + def test_float_to_fp8e4m3(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), ml_dtypes.float8_e4m3fn(x).tobytes()[0]) + + def test_float_to_fp8e4m3_extreme_values(self): + np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126) + np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 126) + np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 126) + np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX, dtypes.fp8e4m3), 254) + np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 254) + np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 254) + np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e4m3), 127) + np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255) + + @given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX)) + def test_float_to_fp8e5m2(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), ml_dtypes.float8_e5m2(x).tobytes()[0]) + + def test_float_to_fp8e5m2_extreme_values(self): + np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123) + np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 123) + np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 123) + np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX, dtypes.fp8e5m2), 251) + np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 251) + np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 251) + np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e5m2), 126) + np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254) + + @given(strat.integers(min_value=0, max_value=255)) + def test_fp8e4m3_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), np.uint8(x).view(ml_dtypes.float8_e4m3fn).item()) + + @given(strat.integers(min_value=0, max_value=255)) + def test_fp8e5m2_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), np.uint8(x).view(ml_dtypes.float8_e5m2).item()) + @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported") class TestBFloat16(unittest.TestCase): def test_bf16_creation_numpy(self): @@ -459,6 +496,18 @@ class TestHelpers(unittest.TestCase): self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf) self.assertEqual(truncate_bf16(min_bf16 * 1.00001), -math.inf) + @given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True)) + def test_truncate_fp8e4m3(self, x): + if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX) + elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX) + else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x)) + + @given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True)) + def test_truncate_fp8e5m2(self, x): + if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX) + elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX) + else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x)) + class TestTypeSpec(unittest.TestCase): def setUp(self): self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index cffc822649..12424ce980 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -198,8 +198,71 @@ def truncate_bf16(x): bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0] return bf +# fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp +def float_to_fp8(x: float, dtype: DType) -> int: + assert dtype in dtypes.fp8s, "Only for fp8s" + config = { + dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000, + "OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F}, + dtypes.fp8e5m2: {"EXP_BIAS": 15, "SIGNIFICAND_BITS": 3, "MANTISSA_MASK": 0x3, "MINDENORM_O2": 0x3EE0000000000000, + "OVERFLOW_THRESHOLD": 0x40EE000000000000 - 1, "MAXNORM": 0x7B, "MINNORM": 0x3F10000000000000, "INF_VALUE": 0x7E} + }[dtype] + xbits, = struct.unpack('Q', struct.pack('d', x)) + FP8_DP_HALF_ULP = 1 << (53 - config["SIGNIFICAND_BITS"] - 1) + sign = ((xbits >> 63) & 1) << 7 + exp = (((xbits >> 52) & 0x7FF) - 1023 + config["EXP_BIAS"]) + mantissa = (xbits >> (53 - config["SIGNIFICAND_BITS"])) & config["MANTISSA_MASK"] + absx = xbits & 0x7FFFFFFFFFFFFFFF + + if absx <= config["MINDENORM_O2"]: res = 0 + elif absx > 0x7FF0000000000000: res = 0x7F if dtype == dtypes.fp8e4m3 else 0x7E | mantissa + elif absx > config["OVERFLOW_THRESHOLD"]: res = config["MAXNORM"] + elif absx >= config["MINNORM"]: + res = ((exp << (config["SIGNIFICAND_BITS"] - 1)) | mantissa) + round_bits = xbits & ((FP8_DP_HALF_ULP << 1) - 1) + if (round_bits > FP8_DP_HALF_ULP) or (round_bits == FP8_DP_HALF_ULP and (mantissa & 1)): res = res + 1 + else: + shift = 1 - exp + mantissa |= 1 << (config["SIGNIFICAND_BITS"] - 1) + res = (mantissa >> shift) + round_bits = (xbits | (1 << (53 - 1))) & ((FP8_DP_HALF_ULP << (shift + 1)) - 1) + if (round_bits > (FP8_DP_HALF_ULP << shift)) or (round_bits == (FP8_DP_HALF_ULP << shift) and (res & 1)): + res = res + 1 + + res |= sign + return int(res) + +def fp8_to_float(x: int, dtype: DType) -> float: + assert dtype in dtypes.fp8s, "Only for fp8s" + ur = x << 8 + + if dtype == dtypes.fp8e5m2 and (ur & 0x7FFF) > 0x7C00: ur = 0x7FFF + elif dtype == dtypes.fp8e4m3: + sign = ur & 0x8000 + exponent = ((ur & 0x7800) >> 1) + 0x2000 + mantissa = (ur & 0x0700) >> 1 + absx = x & 0x7F + if absx == 0x7F: ur = 0x7FFF + elif exponent == 0x2000: + if mantissa != 0: + mantissa <<= 1 + while (mantissa & 0x0400) == 0: + mantissa <<= 1 + exponent -= 0x0400 + mantissa &= 0x03FF + else: + exponent = 0 + ur = (sign | exponent) | mantissa + else: + ur = (sign | exponent) | mantissa + + half_bytes = struct.pack('