support for non-uniform sharding (#3154)

* support for non-uniform sharding

* bugfix and more tests

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
chenyu
2024-01-16 20:33:32 -05:00
committed by GitHub
parent 81ae4ea179
commit 14c010958b
3 changed files with 29 additions and 17 deletions

View File

@@ -295,5 +295,17 @@ class TestMultiTensor(unittest.TestCase):
assert isinstance(jf.jit_cache[4].prg, _BufferCopy)
assert isinstance(jf.jit_cache[5].prg, graph_d1)
def test_uneven_shard(self):
for N in range(1, 6):
X = Tensor.rand(4, 1, 257).contiguous().realize()
n = X.numpy()
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
X.shard_(devices, 2)
np.testing.assert_equal(X.numpy(), n)
np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
if __name__ == '__main__':
unittest.main()

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Optional, Union, Any, Tuple, List
import functools
from tinygrad.helpers import all_same, dedup
from tinygrad.helpers import all_same, dedup, round_up, DEBUG
from tinygrad.dtype import DType
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
from tinygrad.lazy import LazyBuffer, create_schedule
@@ -12,16 +12,16 @@ def all_reduce(lbs):
return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}"
sz = lbs[0].shape[axis] // len(lbs)
return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
class MultiLazyBuffer:
def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]):
assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
#assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st"
self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs)
self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
self.shape = tuple(sum(y.shape[a] for y in self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape))
def __repr__(self):
return f"<MLB{chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
@@ -36,7 +36,7 @@ class MultiLazyBuffer:
sz = self.lbs[0].shape[self.axis]
llbs = []
for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]):
pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape))
pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
llbs.append(lb.pad(pad_arg))
return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
@@ -64,14 +64,15 @@ class MultiLazyBuffer:
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis)
def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape))
def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
if self.axis is not None and new_shape[self.axis] == 1:
# all-reduce on sharded axes
return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis)
return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis)
# *** movement ops ***
@@ -81,23 +82,22 @@ class MultiLazyBuffer:
st = ShapeTracker.from_shape(self.shape)
rs = st.real_strides()[self.axis]
new_axis = st.reshape(arg).real_strides().index(rs)
narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg))
return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis)
return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs], new_axis)
def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis"
return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis)
def expand(self, arg:Tuple[sint, ...]):
# NOTE: this assert isn't needed, sharded axis can have dim 1
assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis"
return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis)
assert self.axis is None or arg[self.axis] == self.shape[self.axis], "expand not supported on sharded axis"
return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis)
def permute(self, arg:Tuple[int, ...]):
# all permutes supported!
return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None)
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis"
narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))
return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis)
assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]), "shrinking not supported on sharded axis"
return MultiLazyBuffer(
[x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))) for x in self.lbs], self.axis)
def stride(self, arg:Tuple[int, ...]):
assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis)

View File

@@ -122,7 +122,7 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
if any(b or e for b, e in arg):
zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])