FP8s truncate (#9937)

* truncate fp8

* fix

* maybe like that?

* fix linters

* ruff

* move from extra and add ml_types to tests

* minor changes

* str to dtypes and nan support

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
This commit is contained in:
pkotzbach
2025-04-23 01:12:49 +02:00
committed by GitHub
parent 58180caad3
commit dbbd755cba
3 changed files with 113 additions and 0 deletions

View File

@@ -14,6 +14,7 @@ testing_minimal = [
"pytest-xdist",
"hypothesis==6.131.0",
"z3-solver",
"ml_dtypes"
]
setup(name='tinygrad',

View File

@@ -5,10 +5,12 @@ from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, truncate_bf16, to_dtype
from tinygrad.dtype import truncate, fp8_to_float, float_to_fp8
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from hypothesis import assume, given, settings, strategies as strat
from test.helpers import rand_for_dtype
import ml_dtypes
import pytest
pytestmark = pytest.mark.filterwarnings("ignore")
@@ -19,6 +21,8 @@ core_dtypes = list(DTYPES_DICT.values())
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
FP8E4M3_MAX = 448.0
FP8E5M2_MAX = 57344.0
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
if not is_dtype_supported(dtype): return []
@@ -146,6 +150,39 @@ class TestFp8s(unittest.TestCase):
def test_fp8e4m3_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3).dtype == dtypes.fp8e4m3
def test_fp8e5m2_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2).dtype == dtypes.fp8e5m2
class TestFp8sConversions(unittest.TestCase):
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
def test_float_to_fp8e4m3(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), ml_dtypes.float8_e4m3fn(x).tobytes()[0])
def test_float_to_fp8e4m3_extreme_values(self):
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126)
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 126)
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 126)
np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX, dtypes.fp8e4m3), 254)
np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 254)
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 254)
np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e4m3), 127)
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255)
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX))
def test_float_to_fp8e5m2(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), ml_dtypes.float8_e5m2(x).tobytes()[0])
def test_float_to_fp8e5m2_extreme_values(self):
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123)
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 123)
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 123)
np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX, dtypes.fp8e5m2), 251)
np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 251)
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 251)
np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e5m2), 126)
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254)
@given(strat.integers(min_value=0, max_value=255))
def test_fp8e4m3_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), np.uint8(x).view(ml_dtypes.float8_e4m3fn).item())
@given(strat.integers(min_value=0, max_value=255))
def test_fp8e5m2_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), np.uint8(x).view(ml_dtypes.float8_e5m2).item())
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16(unittest.TestCase):
def test_bf16_creation_numpy(self):
@@ -459,6 +496,18 @@ class TestHelpers(unittest.TestCase):
self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf)
self.assertEqual(truncate_bf16(min_bf16 * 1.00001), -math.inf)
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e4m3(self, x):
if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x))
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e5m2(self, x):
if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x))
class TestTypeSpec(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float

View File

@@ -198,8 +198,71 @@ def truncate_bf16(x):
bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
return bf
# fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
def float_to_fp8(x: float, dtype: DType) -> int:
assert dtype in dtypes.fp8s, "Only for fp8s"
config = {
dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000,
"OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F},
dtypes.fp8e5m2: {"EXP_BIAS": 15, "SIGNIFICAND_BITS": 3, "MANTISSA_MASK": 0x3, "MINDENORM_O2": 0x3EE0000000000000,
"OVERFLOW_THRESHOLD": 0x40EE000000000000 - 1, "MAXNORM": 0x7B, "MINNORM": 0x3F10000000000000, "INF_VALUE": 0x7E}
}[dtype]
xbits, = struct.unpack('Q', struct.pack('d', x))
FP8_DP_HALF_ULP = 1 << (53 - config["SIGNIFICAND_BITS"] - 1)
sign = ((xbits >> 63) & 1) << 7
exp = (((xbits >> 52) & 0x7FF) - 1023 + config["EXP_BIAS"])
mantissa = (xbits >> (53 - config["SIGNIFICAND_BITS"])) & config["MANTISSA_MASK"]
absx = xbits & 0x7FFFFFFFFFFFFFFF
if absx <= config["MINDENORM_O2"]: res = 0
elif absx > 0x7FF0000000000000: res = 0x7F if dtype == dtypes.fp8e4m3 else 0x7E | mantissa
elif absx > config["OVERFLOW_THRESHOLD"]: res = config["MAXNORM"]
elif absx >= config["MINNORM"]:
res = ((exp << (config["SIGNIFICAND_BITS"] - 1)) | mantissa)
round_bits = xbits & ((FP8_DP_HALF_ULP << 1) - 1)
if (round_bits > FP8_DP_HALF_ULP) or (round_bits == FP8_DP_HALF_ULP and (mantissa & 1)): res = res + 1
else:
shift = 1 - exp
mantissa |= 1 << (config["SIGNIFICAND_BITS"] - 1)
res = (mantissa >> shift)
round_bits = (xbits | (1 << (53 - 1))) & ((FP8_DP_HALF_ULP << (shift + 1)) - 1)
if (round_bits > (FP8_DP_HALF_ULP << shift)) or (round_bits == (FP8_DP_HALF_ULP << shift) and (res & 1)):
res = res + 1
res |= sign
return int(res)
def fp8_to_float(x: int, dtype: DType) -> float:
assert dtype in dtypes.fp8s, "Only for fp8s"
ur = x << 8
if dtype == dtypes.fp8e5m2 and (ur & 0x7FFF) > 0x7C00: ur = 0x7FFF
elif dtype == dtypes.fp8e4m3:
sign = ur & 0x8000
exponent = ((ur & 0x7800) >> 1) + 0x2000
mantissa = (ur & 0x0700) >> 1
absx = x & 0x7F
if absx == 0x7F: ur = 0x7FFF
elif exponent == 0x2000:
if mantissa != 0:
mantissa <<= 1
while (mantissa & 0x0400) == 0:
mantissa <<= 1
exponent -= 0x0400
mantissa &= 0x03FF
else:
exponent = 0
ur = (sign | exponent) | mantissa
else:
ur = (sign | exponent) | mantissa
half_bytes = struct.pack('<H', ur)
float32_val = struct.unpack('e', half_bytes)[0]
return float(float32_val)
truncate: dict[DType, Callable] = {dtypes.bool: bool,
dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
**{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,