mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
to and shard is noop for deviceless uop (#16247)
This commit is contained in:
@@ -36,6 +36,16 @@ class TestTinygrad(unittest.TestCase):
|
||||
self.assertTrue(t.uop.has_buffer_identity())
|
||||
np.testing.assert_equal(t.numpy(), 2.0)
|
||||
|
||||
def test_to_deviceless_const(self):
|
||||
t = Tensor(UOp.const(dtypes.float, 2.0))
|
||||
self.assertIs(t.to(f"{Device.DEFAULT}:1"), t)
|
||||
self.assertIs(t.to_(f"{Device.DEFAULT}:1"), t)
|
||||
|
||||
def test_shard_deviceless_const(self):
|
||||
t = Tensor(UOp.const(dtypes.float, 2.0))
|
||||
self.assertIs(t.shard((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")), t)
|
||||
self.assertIs(t.shard_((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")), t)
|
||||
|
||||
def test_plus_equals(self):
|
||||
a = Tensor.randn(10,10)
|
||||
b = Tensor.randn(10,10)
|
||||
|
||||
@@ -364,6 +364,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
Moves the tensor to the given device.
|
||||
"""
|
||||
if self.uop._device is None: return self
|
||||
if (device:=canonicalize_device(device)) == self.device: return self
|
||||
ret = Tensor(self.uop.copy_to_device(device), requires_grad=self.requires_grad)
|
||||
if self.grad is not None: ret.grad = self.grad.to(device)
|
||||
@@ -386,6 +387,7 @@ class Tensor(OpMixin):
|
||||
print(t.shard((t.device, t.device), axis=1).uop)
|
||||
```
|
||||
"""
|
||||
if self.uop._device is None: return self
|
||||
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
|
||||
if len(devices) == 1: return self.to(devices[0])
|
||||
devices = cast(tuple[str, ...], canonicalize_device(devices))
|
||||
|
||||
Reference in New Issue
Block a user