From f8cc971c3be9bcbfb2adb2f62ce3425671dbb4f3 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Fri, 17 Jan 2025 12:48:39 -0500
Subject: [PATCH] raise RuntimeError for uneven shards in Tensor.shard [pr] (#8656)

---
Review notes (this section is ignored when the patch is applied):
- Before this change the per-device shard size was ceildiv(shape[axis], len(devices)),
  and the `sizes` clamp (max(0, min(...))) let trailing devices receive a smaller
  or empty shard. After it, a non-divisible axis raises RuntimeError up front.
- With divisibility enforced, every entry of `sizes` equals `sz`, so the
  max/min clamp on the following line is now a no-op; this patch leaves it untouched.
- The new test shards a 256-element tensor across devices_3 and expects the
  RuntimeError.

 test/test_multitensor.py | 5 +++++
 tinygrad/tensor.py       | 3 ++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 8e96e2ba59..128454c41b 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -56,6 +56,11 @@ class TestMultiTensor(unittest.TestCase):
     assert lb.shape == (128,)
     (X + X).realize()
 
+  def test_shard_not_multiple(self):
+    X = Tensor.ones(256).contiguous().realize()
+    with self.assertRaises(RuntimeError):
+      X.shard_(devices_3, 0)
+
   def test_tensor_from_multi(self):
     X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
     Y = Tensor(X.lazydata)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e0a5dc0867..1a48e515de 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -408,7 +408,8 @@ class Tensor(SimpleMathTrait):
     if axis is None: lbs = [self.lazydata] * len(devices)
     else:
       axis = self._resolve_dim(axis)
-      sz = ceildiv(self.shape[axis], len(devices))
+      if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
+      sz = self.shape[axis] // len(devices)
       sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
       lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)]
     sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]