diff --git a/test/test_dtype.py b/test/test_dtype.py index 8f71cfd4f1..7d45ea051b 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -566,7 +566,21 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16 assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16 -class TestImplicitFunctionTypeChange(unittest.TestCase): + @unittest.skipIf(not is_dtype_supported(dtypes.float16), "need float16") + def test_gradient_dtype(self): + for default_dtype in [dtypes.float16, dtypes.float32]: + old_default_float = dtypes.default_float + try: + dtypes.default_float = default_dtype + for datatype in [dtypes.float16, dtypes.float32]: + a = Tensor([1, 2, 3], dtype=datatype, requires_grad=True) + b = (a * 5).sum() + b.backward() # if there is dtype mismatch, lazy should assert + assert a.grad.dtype == a.dtype + np.testing.assert_allclose(a.grad.numpy(), Tensor([5, 5, 5], dtype=datatype).numpy()) + finally: + dtypes.default_float = old_default_float + def test_functions(self): result = [] for func in [ diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 365dfa248a..266019d48d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -343,7 +343,7 @@ class Tensor: # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous # this is "implicit gradient creation" - self.grad = Tensor(1.0, device=self.device, requires_grad=False) + self.grad = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) for t0 in reversed(self.deepwalk()): if t0.grad is None: raise RuntimeError("tensor has no grad")