From 14c010958b089a2eceb2e9b8bbad2eaaff1718c3 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 16 Jan 2024 20:33:32 -0500 Subject: [PATCH] support for non-uniform sharding (#3154) * support for non-uniform sharding * bugfix and more tests --------- Co-authored-by: George Hotz --- test/test_multitensor.py | 12 ++++++++++++ tinygrad/features/multi.py | 32 ++++++++++++++++---------------- tinygrad/shape/view.py | 2 +- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 66115a2aaa..7dc75b6a31 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -295,5 +295,17 @@ class TestMultiTensor(unittest.TestCase): assert isinstance(jf.jit_cache[4].prg, _BufferCopy) assert isinstance(jf.jit_cache[5].prg, graph_d1) + def test_uneven_shard(self): + for N in range(1, 6): + X = Tensor.rand(4, 1, 257).contiguous().realize() + n = X.numpy() + devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N)) + X.shard_(devices, 2) + np.testing.assert_equal(X.numpy(), n) + np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257))) + np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257]) + np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1))) + np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1))) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py index 2115afa85b..a90079898b 100644 --- a/tinygrad/features/multi.py +++ b/tinygrad/features/multi.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional, Union, Any, Tuple, List import functools -from tinygrad.helpers import all_same, dedup +from tinygrad.helpers import all_same, dedup, round_up, DEBUG from tinygrad.dtype import DType from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps from tinygrad.lazy import LazyBuffer, create_schedule @@ -12,16 +12,16 @@ def all_reduce(lbs): return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs] def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]: - assert lbs[0].shape[axis] % len(lbs) == 0, f"{lbs[0].shape=} {axis=} {len(lbs)=}" - sz = lbs[0].shape[axis] // len(lbs) - return [lb.shrink(tuple((0,s) if a != axis else (sz*i,sz*(i+1)) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)] + if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}") + sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs) + return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)] class MultiLazyBuffer: def __init__(self, lbs:List[LazyBuffer], axis:Optional[int]): assert all(isinstance(x, LazyBuffer) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them" - assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st" + #assert all_same([(x.shape, x.dtype, x.st) for x in lbs]), "all multilazybuffer needs same shape, dtype, and st" self.lbs, self.axis, self.dtype, self.device = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs) - self.shape = tuple(s*len(self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape)) + self.shape = tuple(sum(y.shape[a] for y in self.lbs) if a == self.axis else s for a,s in enumerate(lbs[0].shape)) def __repr__(self): return f"" @@ -36,7 +36,7 @@ class MultiLazyBuffer: sz = self.lbs[0].shape[self.axis] llbs = [] for i,lb in enumerate([lb.copy_to_device(device) for lb in self.lbs]): - pad_arg = tuple((0,0) if a != self.axis else (sz*i,(s*len(self.lbs))-sz*(i+1)) for a,s in enumerate(lb.shape)) + pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape))) llbs.append(lb.pad(pad_arg)) return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs) @@ -64,14 +64,15 @@ class MultiLazyBuffer: else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis)) return MultiLazyBuffer([lsrcs[0].e(op, *lsrcs[1:], arg=arg) for lsrcs in zip(*srcs)], axis) - def _shape_to_single_shard(self, shape): return tuple(s//len(self.lbs) if a == self.axis else s for a,s in enumerate(shape)) + def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]: + return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape)) def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer: if self.axis is not None and new_shape[self.axis] == 1: # all-reduce on sharded axes return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None) # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct - return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape)) for x in self.lbs], self.axis) + return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis) # *** movement ops *** @@ -81,23 +82,22 @@ class MultiLazyBuffer: st = ShapeTracker.from_shape(self.shape) rs = st.real_strides()[self.axis] new_axis = st.reshape(arg).real_strides().index(rs) - narg = tuple(s//len(self.lbs) if a == new_axis else s for a,s in enumerate(arg)) - return MultiLazyBuffer([x.reshape(narg) for x in self.lbs], new_axis) + return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs], new_axis) def pad(self, arg:Tuple[Tuple[sint, sint], ...]): assert self.axis is None or arg[self.axis] == (0,0), "padding not supported on sharded axis" return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis) def expand(self, arg:Tuple[sint, ...]): # NOTE: this assert isn't needed, sharded axis can have dim 1 - assert self.axis is None or arg[self.axis] == self.lbs[0].shape[self.axis] * len(self.lbs), "expand not supported on sharded axis" - return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg)) for x in self.lbs], self.axis) + assert self.axis is None or arg[self.axis] == self.shape[self.axis], "expand not supported on sharded axis" + return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis) def permute(self, arg:Tuple[int, ...]): # all permutes supported! return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None) def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): - assert self.axis is None or arg[self.axis] == (0, self.lbs[0].shape[self.axis] * len(self.lbs)), "shrinking not supported on sharded axis" - narg = tuple((s1//len(self.lbs), s2//len(self.lbs)) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg)) - return MultiLazyBuffer([x.shrink(narg) for x in self.lbs], self.axis) + assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]), "shrinking not supported on sharded axis" + return MultiLazyBuffer( + [x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else (s1,s2) for a,(s1,s2) in enumerate(arg))) for x in self.lbs], self.axis) def stride(self, arg:Tuple[int, ...]): assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis" return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 1784ab8cb7..1900b9028e 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -122,7 +122,7 @@ class View: @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View: - assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape) + assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}" if any(b or e for b, e in arg): zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)]) mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])