mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
test float_to_bf16 round-to-even behavior (#11849)
Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
|
||||
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
|
||||
|
||||
def u32_to_f32(u): return struct.unpack('f', struct.pack('I', u))[0]
|
||||
def f32_to_u32(f): return struct.unpack('I', struct.pack('f', f))[0]
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
|
||||
@@ -117,7 +118,7 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertEqual(float_to_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
|
||||
self.assertTrue(math.isnan(float_to_bf16(math.nan)))
|
||||
|
||||
def test_truncate_bf16_nan(self):
|
||||
def test_float_to_bf16_nan(self):
|
||||
# In f32, NaN = exp 0xFF and mantissa ≠ 0. Quiet-vs-signaling is bit 22 of the mantissa: 1 = qNaN, 0 = sNaN.
|
||||
# qNaN(+/-), sNaN(+/-) overflow(+/-)
|
||||
patterns = [0x7FC00001, 0xFFC00001, 0x7F800001, 0xFF800001, 0x7FFFFFFF, 0xFFFFFFFF]
|
||||
@@ -128,6 +129,56 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertTrue(math.isnan(y))
|
||||
self.assertTrue(math.isnan(t))
|
||||
|
||||
def test_float_to_bf16_round(self):
|
||||
# round_to_nearest_even
|
||||
uppers = [0x3f800000, 0x41230000, 0xC1460000] # 1.0, 10.1875, -12.375
|
||||
for upper in uppers:
|
||||
base = upper & 0xFFFF0000
|
||||
base_f32 = u32_to_f32(base)
|
||||
base_f32_round_up = u32_to_f32(base + 0x00010000)
|
||||
|
||||
# low < 0x8000(0.5ULP) -> round down
|
||||
x = u32_to_f32(base | 0x00007000)
|
||||
self.assertEqual(float_to_bf16(x), base_f32)
|
||||
self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32)
|
||||
|
||||
# low > 0x8000(0.5ULP) -> round up
|
||||
x = u32_to_f32(base | 0x0000C000)
|
||||
self.assertEqual(float_to_bf16(x), base_f32_round_up)
|
||||
self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32_round_up)
|
||||
|
||||
# low == 0x8000(0.5ULP) and LSB even -> round down
|
||||
if ((upper >> 16) & 1) == 0:
|
||||
x = u32_to_f32(base | 0x00008000)
|
||||
self.assertEqual(float_to_bf16(x), base_f32)
|
||||
self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32)
|
||||
# low == 0x8000(0.5ULP) and LSB odd -> round up
|
||||
else:
|
||||
x = u32_to_f32(base | 0x00008000)
|
||||
self.assertEqual(float_to_bf16(x), base_f32_round_up)
|
||||
self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32_round_up)
|
||||
|
||||
def test_float_to_bf16_boundary(self):
|
||||
# bf16 max finite: exp=0xFE, faction=0x7F => 0x7F7F0000(f32)
|
||||
# bf16 inf(+/-): exp=0xFF
|
||||
base = 0x7F7F0000
|
||||
inf_u32 = 0x7F800000
|
||||
|
||||
# low < 0.5ULP
|
||||
x = u32_to_f32(base | 0x00007FFF)
|
||||
self.assertEqual(f32_to_u32(float_to_bf16(x)), base)
|
||||
self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), base)
|
||||
|
||||
# low > 0.5ULP -> overflows to +inf
|
||||
x = u32_to_f32(base | 0x0000C000)
|
||||
self.assertEqual(f32_to_u32(float_to_bf16(x)), inf_u32)
|
||||
self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), inf_u32)
|
||||
|
||||
# low == 0.5ULP and LSB odd -> overflows to +inf
|
||||
x = u32_to_f32(base | 0x00008000)
|
||||
self.assertEqual(f32_to_u32(float_to_bf16(x)), inf_u32)
|
||||
self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), inf_u32)
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
|
||||
def test_truncate_fp8e4m3(self, x):
|
||||
if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
|
||||
|
||||
Reference in New Issue
Block a user