to and shard is noop for deviceless uop (#16247)

This commit is contained in:
chenyu
2026-05-18 16:11:10 -04:00
committed by GitHub
parent 50481ec9b4
commit 73e6b4963b
2 changed files with 12 additions and 0 deletions

View File

@@ -36,6 +36,16 @@ class TestTinygrad(unittest.TestCase):
self.assertTrue(t.uop.has_buffer_identity())
np.testing.assert_equal(t.numpy(), 2.0)
def test_to_deviceless_const(self):
t = Tensor(UOp.const(dtypes.float, 2.0))
self.assertIs(t.to(f"{Device.DEFAULT}:1"), t)
self.assertIs(t.to_(f"{Device.DEFAULT}:1"), t)
def test_shard_deviceless_const(self):
t = Tensor(UOp.const(dtypes.float, 2.0))
self.assertIs(t.shard((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")), t)
self.assertIs(t.shard_((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")), t)
def test_plus_equals(self):
a = Tensor.randn(10,10)
b = Tensor.randn(10,10)

View File

@@ -364,6 +364,7 @@ class Tensor(OpMixin):
"""
Moves the tensor to the given device.
"""
if self.uop._device is None: return self
if (device:=canonicalize_device(device)) == self.device: return self
ret = Tensor(self.uop.copy_to_device(device), requires_grad=self.requires_grad)
if self.grad is not None: ret.grad = self.grad.to(device)
@@ -386,6 +387,7 @@ class Tensor(OpMixin):
print(t.shard((t.device, t.device), axis=1).uop)
```
"""
if self.uop._device is None: return self
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
if len(devices) == 1: return self.to(devices[0])
devices = cast(tuple[str, ...], canonicalize_device(devices))