mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
support for non-uniform sharding (#3154)
* support for non-uniform sharding * bugfix and more tests --------- Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
@@ -295,5 +295,17 @@ class TestMultiTensor(unittest.TestCase):
|
||||
assert isinstance(jf.jit_cache[4].prg, _BufferCopy)
|
||||
assert isinstance(jf.jit_cache[5].prg, graph_d1)
|
||||
|
||||
def test_uneven_shard(self):
|
||||
for N in range(1, 6):
|
||||
X = Tensor.rand(4, 1, 257).contiguous().realize()
|
||||
n = X.numpy()
|
||||
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
|
||||
X.shard_(devices, 2)
|
||||
np.testing.assert_equal(X.numpy(), n)
|
||||
np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
|
||||
np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
|
||||
np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
|
||||
np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Union, Any, Tuple, List
|
||||
import functools
|
||||
from tinygrad.helpers import all_same, dedup
|
||||
from tinygrad.helpers import all_same, dedup, round_up, DEBUG
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer, create_schedule
|
||||
@@ -12,16 +12,16 @@ def all_reduce(lbs):
|
||||
return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
|
||||
|
||||
def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
|
||||
assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}"
|
||||
sz = lbs[0].shape[axis] // len(lbs)
|
||||
return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
|
||||
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
|
||||
sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
|
||||
return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
|
||||
|
||||
class MultiLazyBuffer:
|
||||
def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
|
||||
assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
|
||||
assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
|
||||
#assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
|
||||
self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
|
||||
self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
|
||||
self.shape = tuple(sum(y.shape[a] for y in self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MLB{chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
|
||||
@@ -36,7 +36,7 @@ class MultiLazyBuffer:
|
||||
sz = self.lbs[0].shape[self.axis]
|
||||
llbs = []
|
||||
for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
|
||||
pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape))
|
||||
pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
|
||||
llbs.append(lb.pad(pad_arg))
|
||||
return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
|
||||
|
||||
@@ -64,14 +64,15 @@ class MultiLazyBuffer:
|
||||
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
|
||||
return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)
|
||||
|
||||
def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape))
|
||||
def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
|
||||
return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
|
||||
|
||||
def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
|
||||
if self.axis is not None and new_shape[self.axis] == 1:
|
||||
# all-reduce on sharded axes
|
||||
return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
|
||||
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
|
||||
return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis)
|
||||
return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis)
|
||||
|
||||
# *** movement ops ***
|
||||
|
||||
@@ -81,23 +82,22 @@ class MultiLazyBuffer:
|
||||
st = ShapeTracker.from_shape(self.shape)
|
||||
rs = st.real_strides()[self.axis]
|
||||
new_axis = st.reshape(arg).real_strides().index(rs)
|
||||
narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg))
|
||||
return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis)
|
||||
return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs], new_axis)
|
||||
|
||||
def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
|
||||
assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
|
||||
return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
|
||||
def expand(self, arg:Tuple[sint, ...]):
|
||||
# NOTE: this assert isn't needed, sharded axis can have dim 1
|
||||
assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis"
|
||||
return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis)
|
||||
assert self.axis is None or arg[self.axis] == self.shape[self.axis], "expand not supported on sharded axis"
|
||||
return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis)
|
||||
def permute(self, arg:Tuple[int, ...]):
|
||||
# all permutes supported!
|
||||
return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
|
||||
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
|
||||
assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis"
|
||||
narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))
|
||||
return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis)
|
||||
assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]), "shrinking not supported on sharded axis"
|
||||
return MultiLazyBuffer(
|
||||
[x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))) for x in self.lbs], self.axis)
|
||||
def stride(self, arg:Tuple[int, ...]):
|
||||
assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
|
||||
return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)
|
||||
|
||||
@@ -122,7 +122,7 @@ class View:
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
|
||||
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
|
||||
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
|
||||
if any(b or e for b, e in arg):
|
||||
zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
|
||||
mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
|
||||
|
||||
Reference in New Issue
Block a user