diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 3da2a7792d..fe7c338173 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1476,7 +1476,7 @@ def train_llama3(): grad_norm = optim.fstep(grads) scheduler.step() - for g in grads: g.assign(g.const_like(0)) + for g in grads: g.assign(0) lr_cpu = optim.lr.float().to("CPU") grad_norm_cpu = grad_norm.float().to("CPU") diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py index dd03d04cbd..0626138120 100644 --- a/test/backend/test_multitensor.py +++ b/test/backend/test_multitensor.py @@ -1247,6 +1247,16 @@ class TestMultiAssign(unittest.TestCase): out.assign(ones).realize() self.assertListEqual(out.tolist(), [1,1,1,1]) + def test_multi_assign_scalar(self): + out = Tensor.ones(4).shard(self.device, 0).contiguous().realize() + out.assign(0).realize() + self.assertListEqual(out.tolist(), [0,0,0,0]) + + def test_multi_assign_const_like(self): + out = Tensor.ones(4).shard(self.device, 0).contiguous().realize() + out.assign(out.const_like(7)).realize() + self.assertListEqual(out.tolist(), [7,7,7,7]) + def test_multi_assign_piece(self): out = Tensor.zeros(4,4).shard(self.device, 0).contiguous().realize() ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index bf8dc537bb..2154e2fb43 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -263,7 +263,8 @@ class Tensor(OpMixin): if not is_disk and x.uop.device is not None and self.device is not None and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}") if not is_disk and self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}") - if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}") + if isinstance(self.device, tuple) and x.uop.device is not None and self.uop.axis != x.uop.axis: + raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}") # TODO: this is a hack for writing to DISK. remove with working assign if is_disk: