test float_to_bf16 round-to-even behavior (#11849)

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
b1tg
2025-08-27 00:16:10 +08:00
committed by GitHub
parent 409399c609
commit 1dd613cb89

View File

@@ -27,6 +27,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
def u32_to_f32(u): return struct.unpack('f', struct.pack('I', u))[0]
def f32_to_u32(f): return struct.unpack('I', struct.pack('f', f))[0]
class TestHelpers(unittest.TestCase):
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
@@ -117,7 +118,7 @@ class TestHelpers(unittest.TestCase):
self.assertEqual(float_to_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
self.assertTrue(math.isnan(float_to_bf16(math.nan)))
def test_truncate_bf16_nan(self):
def test_float_to_bf16_nan(self):
# In f32, NaN = exp 0xFF and mantissa ≠ 0. Quiet-vs-signaling is bit 22 of the mantissa: 1 = qNaN, 0 = sNaN.
# qNaN(+/-), sNaN(+/-) overflow(+/-)
patterns = [0x7FC00001, 0xFFC00001, 0x7F800001, 0xFF800001, 0x7FFFFFFF, 0xFFFFFFFF]
@@ -128,6 +129,56 @@ class TestHelpers(unittest.TestCase):
self.assertTrue(math.isnan(y))
self.assertTrue(math.isnan(t))
def test_float_to_bf16_round(self):
# round_to_nearest_even
uppers = [0x3f800000, 0x41230000, 0xC1460000] # 1.0, 10.1875, -12.375
for upper in uppers:
base = upper & 0xFFFF0000
base_f32 = u32_to_f32(base)
base_f32_round_up = u32_to_f32(base + 0x00010000)
# low < 0x8000(0.5ULP) -> round down
x = u32_to_f32(base | 0x00007000)
self.assertEqual(float_to_bf16(x), base_f32)
self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32)
# low > 0x8000(0.5ULP) -> round up
x = u32_to_f32(base | 0x0000C000)
self.assertEqual(float_to_bf16(x), base_f32_round_up)
self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32_round_up)
# low == 0x8000(0.5ULP) and LSB even -> round down
if ((upper >> 16) & 1) == 0:
x = u32_to_f32(base | 0x00008000)
self.assertEqual(float_to_bf16(x), base_f32)
self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32)
# low == 0x8000(0.5ULP) and LSB odd -> round up
else:
x = u32_to_f32(base | 0x00008000)
self.assertEqual(float_to_bf16(x), base_f32_round_up)
self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32_round_up)
def test_float_to_bf16_boundary(self):
# bf16 max finite: exp=0xFE, faction=0x7F => 0x7F7F0000(f32)
# bf16 inf(+/-): exp=0xFF
base = 0x7F7F0000
inf_u32 = 0x7F800000
# low < 0.5ULP
x = u32_to_f32(base | 0x00007FFF)
self.assertEqual(f32_to_u32(float_to_bf16(x)), base)
self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), base)
# low > 0.5ULP -> overflows to +inf
x = u32_to_f32(base | 0x0000C000)
self.assertEqual(f32_to_u32(float_to_bf16(x)), inf_u32)
self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), inf_u32)
# low == 0.5ULP and LSB odd -> overflows to +inf
x = u32_to_f32(base | 0x00008000)
self.assertEqual(f32_to_u32(float_to_bf16(x)), inf_u32)
self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), inf_u32)
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e4m3(self, x):
if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)