deviceless const skip axis check (#16496)

This commit is contained in:
wozeparrot
2026-06-03 22:13:20 -04:00
committed by GitHub
parent f7f03bd7e5
commit fd13080636
3 changed files with 13 additions and 2 deletions

View File

@@ -1476,7 +1476,7 @@ def train_llama3():
grad_norm = optim.fstep(grads)
scheduler.step()
for g in grads: g.assign(g.const_like(0))
for g in grads: g.assign(0)
lr_cpu = optim.lr.float().to("CPU")
grad_norm_cpu = grad_norm.float().to("CPU")

View File

@@ -1247,6 +1247,16 @@ class TestMultiAssign(unittest.TestCase):
out.assign(ones).realize()
self.assertListEqual(out.tolist(), [1,1,1,1])
def test_multi_assign_scalar(self):
  # Assigning a bare Python scalar into a tensor sharded on axis 0 must overwrite every element.
  sharded = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
  sharded.assign(0).realize()
  self.assertListEqual(sharded.tolist(), [0] * 4)
def test_multi_assign_const_like(self):
  # Assigning a const_like(7) built from the sharded tensor itself must overwrite every element with 7.
  sharded = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
  sevens = sharded.const_like(7)
  sharded.assign(sevens).realize()
  self.assertListEqual(sharded.tolist(), [7] * 4)
def test_multi_assign_piece(self):
out = Tensor.zeros(4,4).shard(self.device, 0).contiguous().realize()
ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()

View File

@@ -263,7 +263,8 @@ class Tensor(OpMixin):
if not is_disk and x.uop.device is not None and self.device is not None and self.device != x.device:
raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
if not is_disk and self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
if isinstance(self.device, tuple) and x.uop.device is not None and self.uop.axis != x.uop.axis:
raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
# TODO: this is a hack for writing to DISK. remove with working assign
if is_disk: