diff --git a/test/unit/test_dtype_spec.py b/test/unit/test_dtype_spec.py index 65460f5e53..da690a8170 100644 --- a/test/unit/test_dtype_spec.py +++ b/test/unit/test_dtype_spec.py @@ -27,6 +27,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e def u32_to_f32(u): return struct.unpack('f', struct.pack('I', u))[0] +def f32_to_u32(f): return struct.unpack('I', struct.pack('f', f))[0] class TestHelpers(unittest.TestCase): signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) @@ -117,7 +118,7 @@ class TestHelpers(unittest.TestCase): self.assertEqual(float_to_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item()) self.assertTrue(math.isnan(float_to_bf16(math.nan))) - def test_truncate_bf16_nan(self): + def test_float_to_bf16_nan(self): # In f32, NaN = exp 0xFF and mantissa ≠ 0. Quiet-vs-signaling is bit 22 of the mantissa: 1 = qNaN, 0 = sNaN. 
# qNaN(+/-), sNaN(+/-) overflow(+/-) patterns = [0x7FC00001, 0xFFC00001, 0x7F800001, 0xFF800001, 0x7FFFFFFF, 0xFFFFFFFF] @@ -128,6 +129,56 @@ class TestHelpers(unittest.TestCase): self.assertTrue(math.isnan(y)) self.assertTrue(math.isnan(t)) + def test_float_to_bf16_round(self): + # round_to_nearest_even + uppers = [0x3f800000, 0x41230000, 0xC1460000] # 1.0, 10.1875, -12.375 + for upper in uppers: + base = upper & 0xFFFF0000 + base_f32 = u32_to_f32(base) + base_f32_round_up = u32_to_f32(base + 0x00010000) + + # low < 0x8000(0.5ULP) -> round down + x = u32_to_f32(base | 0x00007000) + self.assertEqual(float_to_bf16(x), base_f32) + self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32) + + # low > 0x8000(0.5ULP) -> round up + x = u32_to_f32(base | 0x0000C000) + self.assertEqual(float_to_bf16(x), base_f32_round_up) + self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32_round_up) + + # low == 0x8000(0.5ULP) and LSB even -> round down + if ((upper >> 16) & 1) == 0: + x = u32_to_f32(base | 0x00008000) + self.assertEqual(float_to_bf16(x), base_f32) + self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32) + # low == 0x8000(0.5ULP) and LSB odd -> round up + else: + x = u32_to_f32(base | 0x00008000) + self.assertEqual(float_to_bf16(x), base_f32_round_up) + self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32_round_up) + + def test_float_to_bf16_boundary(self): + # bf16 max finite: exp=0xFE, fraction=0x7F => 0x7F7F0000(f32) + # bf16 inf(+/-): exp=0xFF + base = 0x7F7F0000 + inf_u32 = 0x7F800000 + + # low < 0.5ULP -> rounds down, stays at max finite + x = u32_to_f32(base | 0x00007FFF) + self.assertEqual(f32_to_u32(float_to_bf16(x)), base) + self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), base) + + # low > 0.5ULP -> overflows to +inf + x = u32_to_f32(base | 0x0000C000) + self.assertEqual(f32_to_u32(float_to_bf16(x)), inf_u32) + self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), 
inf_u32) + + # low == 0.5ULP and LSB odd -> overflows to +inf + x = u32_to_f32(base | 0x00008000) + self.assertEqual(f32_to_u32(float_to_bf16(x)), inf_u32) + self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), inf_u32) + @given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True)) def test_truncate_fp8e4m3(self, x): if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)