mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
FP8s truncate (#9937)
* truncate fp8 * fix * maybe like that? * fix linters * ruff * move from extra and add ml_types to tests * minor changes * str to dtypes and nan support --------- Co-authored-by: pkotzbach <pawkotz@gmail.com>
This commit is contained in:
1
setup.py
1
setup.py
@@ -14,6 +14,7 @@ testing_minimal = [
|
||||
"pytest-xdist",
|
||||
"hypothesis==6.131.0",
|
||||
"z3-solver",
|
||||
"ml_dtypes"
|
||||
]
|
||||
|
||||
setup(name='tinygrad',
|
||||
|
||||
@@ -5,10 +5,12 @@ from typing import Any, List
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG, CI
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, truncate_bf16, to_dtype
|
||||
from tinygrad.dtype import truncate, fp8_to_float, float_to_fp8
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
from test.helpers import rand_for_dtype
|
||||
import ml_dtypes
|
||||
import pytest
|
||||
pytestmark = pytest.mark.filterwarnings("ignore")
|
||||
|
||||
@@ -19,6 +21,8 @@ core_dtypes = list(DTYPES_DICT.values())
|
||||
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
|
||||
FP8E4M3_MAX = 448.0
|
||||
FP8E5M2_MAX = 57344.0
|
||||
|
||||
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
|
||||
if not is_dtype_supported(dtype): return []
|
||||
@@ -146,6 +150,39 @@ class TestFp8s(unittest.TestCase):
|
||||
def test_fp8e4m3_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3).dtype == dtypes.fp8e4m3
|
||||
def test_fp8e5m2_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2).dtype == dtypes.fp8e5m2
|
||||
|
||||
class TestFp8sConversions(unittest.TestCase):
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
|
||||
def test_float_to_fp8e4m3(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), ml_dtypes.float8_e4m3fn(x).tobytes()[0])
|
||||
|
||||
def test_float_to_fp8e4m3_extreme_values(self):
|
||||
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126)
|
||||
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 126)
|
||||
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 126)
|
||||
np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX, dtypes.fp8e4m3), 254)
|
||||
np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 254)
|
||||
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 254)
|
||||
np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e4m3), 127)
|
||||
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255)
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX))
|
||||
def test_float_to_fp8e5m2(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), ml_dtypes.float8_e5m2(x).tobytes()[0])
|
||||
|
||||
def test_float_to_fp8e5m2_extreme_values(self):
|
||||
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123)
|
||||
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 123)
|
||||
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 123)
|
||||
np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX, dtypes.fp8e5m2), 251)
|
||||
np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 251)
|
||||
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 251)
|
||||
np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e5m2), 126)
|
||||
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254)
|
||||
|
||||
@given(strat.integers(min_value=0, max_value=255))
|
||||
def test_fp8e4m3_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), np.uint8(x).view(ml_dtypes.float8_e4m3fn).item())
|
||||
|
||||
@given(strat.integers(min_value=0, max_value=255))
|
||||
def test_fp8e5m2_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), np.uint8(x).view(ml_dtypes.float8_e5m2).item())
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
|
||||
class TestBFloat16(unittest.TestCase):
|
||||
def test_bf16_creation_numpy(self):
|
||||
@@ -459,6 +496,18 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf)
|
||||
self.assertEqual(truncate_bf16(min_bf16 * 1.00001), -math.inf)
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
|
||||
def test_truncate_fp8e4m3(self, x):
|
||||
if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
|
||||
elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x))
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
|
||||
def test_truncate_fp8e5m2(self, x):
|
||||
if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
|
||||
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x))
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
|
||||
|
||||
@@ -198,8 +198,71 @@ def truncate_bf16(x):
|
||||
bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
|
||||
return bf
|
||||
|
||||
# fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
|
||||
def float_to_fp8(x: float, dtype: DType) -> int:
|
||||
assert dtype in dtypes.fp8s, "Only for fp8s"
|
||||
config = {
|
||||
dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000,
|
||||
"OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F},
|
||||
dtypes.fp8e5m2: {"EXP_BIAS": 15, "SIGNIFICAND_BITS": 3, "MANTISSA_MASK": 0x3, "MINDENORM_O2": 0x3EE0000000000000,
|
||||
"OVERFLOW_THRESHOLD": 0x40EE000000000000 - 1, "MAXNORM": 0x7B, "MINNORM": 0x3F10000000000000, "INF_VALUE": 0x7E}
|
||||
}[dtype]
|
||||
xbits, = struct.unpack('Q', struct.pack('d', x))
|
||||
FP8_DP_HALF_ULP = 1 << (53 - config["SIGNIFICAND_BITS"] - 1)
|
||||
sign = ((xbits >> 63) & 1) << 7
|
||||
exp = (((xbits >> 52) & 0x7FF) - 1023 + config["EXP_BIAS"])
|
||||
mantissa = (xbits >> (53 - config["SIGNIFICAND_BITS"])) & config["MANTISSA_MASK"]
|
||||
absx = xbits & 0x7FFFFFFFFFFFFFFF
|
||||
|
||||
if absx <= config["MINDENORM_O2"]: res = 0
|
||||
elif absx > 0x7FF0000000000000: res = 0x7F if dtype == dtypes.fp8e4m3 else 0x7E | mantissa
|
||||
elif absx > config["OVERFLOW_THRESHOLD"]: res = config["MAXNORM"]
|
||||
elif absx >= config["MINNORM"]:
|
||||
res = ((exp << (config["SIGNIFICAND_BITS"] - 1)) | mantissa)
|
||||
round_bits = xbits & ((FP8_DP_HALF_ULP << 1) - 1)
|
||||
if (round_bits > FP8_DP_HALF_ULP) or (round_bits == FP8_DP_HALF_ULP and (mantissa & 1)): res = res + 1
|
||||
else:
|
||||
shift = 1 - exp
|
||||
mantissa |= 1 << (config["SIGNIFICAND_BITS"] - 1)
|
||||
res = (mantissa >> shift)
|
||||
round_bits = (xbits | (1 << (53 - 1))) & ((FP8_DP_HALF_ULP << (shift + 1)) - 1)
|
||||
if (round_bits > (FP8_DP_HALF_ULP << shift)) or (round_bits == (FP8_DP_HALF_ULP << shift) and (res & 1)):
|
||||
res = res + 1
|
||||
|
||||
res |= sign
|
||||
return int(res)
|
||||
|
||||
def fp8_to_float(x: int, dtype: DType) -> float:
|
||||
assert dtype in dtypes.fp8s, "Only for fp8s"
|
||||
ur = x << 8
|
||||
|
||||
if dtype == dtypes.fp8e5m2 and (ur & 0x7FFF) > 0x7C00: ur = 0x7FFF
|
||||
elif dtype == dtypes.fp8e4m3:
|
||||
sign = ur & 0x8000
|
||||
exponent = ((ur & 0x7800) >> 1) + 0x2000
|
||||
mantissa = (ur & 0x0700) >> 1
|
||||
absx = x & 0x7F
|
||||
if absx == 0x7F: ur = 0x7FFF
|
||||
elif exponent == 0x2000:
|
||||
if mantissa != 0:
|
||||
mantissa <<= 1
|
||||
while (mantissa & 0x0400) == 0:
|
||||
mantissa <<= 1
|
||||
exponent -= 0x0400
|
||||
mantissa &= 0x03FF
|
||||
else:
|
||||
exponent = 0
|
||||
ur = (sign | exponent) | mantissa
|
||||
else:
|
||||
ur = (sign | exponent) | mantissa
|
||||
|
||||
half_bytes = struct.pack('<H', ur)
|
||||
float32_val = struct.unpack('e', half_bytes)[0]
|
||||
return float(float32_val)
|
||||
|
||||
truncate: dict[DType, Callable] = {dtypes.bool: bool,
|
||||
dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
|
||||
**{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
|
||||
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
||||
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
|
||||
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
||||
|
||||
Reference in New Issue
Block a user