mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
deviceless const skip axis check (#16496)
This commit is contained in:
@@ -1476,7 +1476,7 @@ def train_llama3():
|
||||
grad_norm = optim.fstep(grads)
|
||||
scheduler.step()
|
||||
|
||||
for g in grads: g.assign(g.const_like(0))
|
||||
for g in grads: g.assign(0)
|
||||
|
||||
lr_cpu = optim.lr.float().to("CPU")
|
||||
grad_norm_cpu = grad_norm.float().to("CPU")
|
||||
|
||||
@@ -1247,6 +1247,16 @@ class TestMultiAssign(unittest.TestCase):
|
||||
out.assign(ones).realize()
|
||||
self.assertListEqual(out.tolist(), [1,1,1,1])
|
||||
|
||||
def test_multi_assign_scalar(self):
|
||||
out = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
|
||||
out.assign(0).realize()
|
||||
self.assertListEqual(out.tolist(), [0,0,0,0])
|
||||
|
||||
def test_multi_assign_const_like(self):
|
||||
out = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
|
||||
out.assign(out.const_like(7)).realize()
|
||||
self.assertListEqual(out.tolist(), [7,7,7,7])
|
||||
|
||||
def test_multi_assign_piece(self):
|
||||
out = Tensor.zeros(4,4).shard(self.device, 0).contiguous().realize()
|
||||
ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
|
||||
|
||||
@@ -263,7 +263,8 @@ class Tensor(OpMixin):
|
||||
if not is_disk and x.uop.device is not None and self.device is not None and self.device != x.device:
|
||||
raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
|
||||
if not is_disk and self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
|
||||
if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
|
||||
if isinstance(self.device, tuple) and x.uop.device is not None and self.uop.axis != x.uop.axis:
|
||||
raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
|
||||
|
||||
# TODO: this is a hack for writing to DISK. remove with working assign
|
||||
if is_disk:
|
||||
|
||||
Reference in New Issue
Block a user