fix issue with multi flip (#13115)

This commit is contained in:
George Hotz
2025-11-05 15:28:50 -08:00
committed by GitHub
parent 4027eef264
commit 9b2b535fa4
3 changed files with 8 additions and 1 deletions

View File

@@ -596,6 +596,12 @@ class TestMultiTensor(unittest.TestCase):
# ast are the same on devices
self.assertEqual(len(set(asts)), 1)
def test_flip(self):
rng = Tensor.rand((10, 10, 10))
t0 = rng.shard(devices_2, axis=1)
out = t0.flip(0) + 1
self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)))
def test_reshape_on_axis(self):
t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1)

View File

@@ -117,6 +117,7 @@ class MovementMixin:
```
"""
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self

View File

@@ -186,7 +186,7 @@ def shrink_multi(root:UOp, multi:UOp):
def flip_multi(root:UOp, multi:UOp):
assert multi.axis is None or not root.marg[multi.axis], "flipping not supported on sharded axis"
return multi.src[0].flip(root.marg).multi(multi.axis)
return multi.src[0].flip([i for i,x in enumerate(root.marg) if x]).multi(multi.axis)
# from multiple devices -> one
def copy_multi(multi:UOp, device:UOp):