mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
raise RuntimeError for uneven shards in Tensor.shard [pr] (#8656)
This commit is contained in:
@@ -56,6 +56,11 @@ class TestMultiTensor(unittest.TestCase):
|
||||
assert lb.shape == (128,)
|
||||
(X + X).realize()
|
||||
|
||||
def test_shard_not_multiple(self):
|
||||
X = Tensor.ones(256).contiguous().realize()
|
||||
with self.assertRaises(RuntimeError):
|
||||
X.shard_(devices_3, 0)
|
||||
|
||||
def test_tensor_from_multi(self):
|
||||
X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
|
||||
Y = Tensor(X.lazydata)
|
||||
|
||||
@@ -408,7 +408,8 @@ class Tensor(SimpleMathTrait):
|
||||
if axis is None: lbs = [self.lazydata] * len(devices)
|
||||
else:
|
||||
axis = self._resolve_dim(axis)
|
||||
sz = ceildiv(self.shape[axis], len(devices))
|
||||
if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
|
||||
sz = self.shape[axis] // len(devices)
|
||||
sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
|
||||
lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)]
|
||||
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
|
||||
|
||||
Reference in New Issue
Block a user