diff --git a/test/test_dtype.py b/test/test_dtype.py index 4036dbc175..c4126a66ea 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -42,7 +42,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target): raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target) -def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist()) +def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np))) def _test_bitcast(a:Tensor, target_dtype:DType, target=None): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist()) class TestDType(unittest.TestCase):