diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 77e6e8415a..366b2aa9a8 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -1,7 +1,8 @@ from __future__ import annotations import math from typing import Optional, Union, Tuple, List -from tinygrad.tensor import Tensor +from tinygrad.tensor import Tensor, dtypes +from tinygrad.device import is_dtype_supported from tinygrad.helpers import prod, make_tuple, flatten from tinygrad.nn import optim, state, datasets # noqa: F401 @@ -36,7 +37,7 @@ class BatchNorm: self.weight: Optional[Tensor] = Tensor.ones(sz) if affine else None self.bias: Optional[Tensor] = Tensor.zeros(sz) if affine else None - self.num_batches_tracked = Tensor.zeros(1, dtype='long', requires_grad=False) + self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False) if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False) def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]: