raise RuntimeError for uneven shards in Tensor.shard [pr] (#8656)

This commit is contained in:
chenyu
2025-01-17 12:48:39 -05:00
committed by GitHub
parent 3506a7585f
commit f8cc971c3b
2 changed files with 7 additions and 1 deletions

View File

@@ -56,6 +56,11 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (128,)
(X + X).realize()
def test_shard_not_multiple(self):
X = Tensor.ones(256).contiguous().realize()
with self.assertRaises(RuntimeError):
X.shard_(devices_3, 0)
def test_tensor_from_multi(self):
X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
Y = Tensor(X.lazydata)

View File

@@ -408,7 +408,8 @@ class Tensor(SimpleMathTrait):
if axis is None: lbs = [self.lazydata] * len(devices)
else:
axis = self._resolve_dim(axis)
sz = ceildiv(self.shape[axis], len(devices))
if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
sz = self.shape[axis] // len(devices)
sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)]
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]