mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix issue with multi flip (#13115)
This commit is contained in:
@@ -596,6 +596,12 @@ class TestMultiTensor(unittest.TestCase):
|
||||
# ast are the same on devices
|
||||
self.assertEqual(len(set(asts)), 1)
|
||||
|
||||
def test_flip(self):
|
||||
rng = Tensor.rand((10, 10, 10))
|
||||
t0 = rng.shard(devices_2, axis=1)
|
||||
out = t0.flip(0) + 1
|
||||
self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)))
|
||||
|
||||
def test_reshape_on_axis(self):
|
||||
t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1)
|
||||
|
||||
|
||||
@@ -117,6 +117,7 @@ class MovementMixin:
|
||||
```
|
||||
"""
|
||||
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
|
||||
assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
|
||||
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
|
||||
flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
|
||||
return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self
|
||||
|
||||
@@ -186,7 +186,7 @@ def shrink_multi(root:UOp, multi:UOp):
|
||||
|
||||
def flip_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or not root.marg[multi.axis], "flipping not supported on sharded axis"
|
||||
return multi.src[0].flip(root.marg).multi(multi.axis)
|
||||
return multi.src[0].flip([i for i,x in enumerate(root.marg) if x]).multi(multi.axis)
|
||||
|
||||
# from multiple devices -> one
|
||||
def copy_multi(multi:UOp, device:UOp):
|
||||
|
||||
Reference in New Issue
Block a user