diff --git a/test/test_dtype.py b/test/test_dtype.py index 6311b4046e..067100955a 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -142,13 +142,11 @@ class TestBFloat16(unittest.TestCase): np.testing.assert_allclose(tnp, np.array(data)) def test_bf16_ones(self): - # TODO: fix this with correct bfloat16 cast t = Tensor.ones(3, 5, dtype=dtypes.bfloat16) assert t.dtype == dtypes.bfloat16 np.testing.assert_allclose(t.numpy(), np.ones((3, 5))) def test_bf16_eye(self): - # TODO: fix this with correct bfloat16 cast t = Tensor.eye(3, dtype=dtypes.bfloat16) assert t.dtype == dtypes.bfloat16 np.testing.assert_allclose(t.numpy(), np.eye(3)) @@ -161,8 +159,6 @@ class TestBFloat16DType(unittest.TestCase): def test_float_to_bf16(self): _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16) - # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16) - def test_bf16(self): t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16) t.realize()