From 9b2b535fa470e341fbfcb06e07ee034d7b755ef7 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:28:50 -0800 Subject: [PATCH] fix issue with multi flip (#13115) --- test/test_multitensor.py | 6 ++++++ tinygrad/mixin/movement.py | 1 + tinygrad/schedule/multi.py | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index f987676dbc..bca243c5e5 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -596,6 +596,12 @@ class TestMultiTensor(unittest.TestCase): # ast are the same on devices self.assertEqual(len(set(asts)), 1) + def test_flip(self): + rng = Tensor.rand((10, 10, 10)) + t0 = rng.shard(devices_2, axis=1) + out = t0.flip(0) + 1 + self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device))) + def test_reshape_on_axis(self): t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1) diff --git a/tinygrad/mixin/movement.py b/tinygrad/mixin/movement.py index faecfee2a1..a171f21767 100644 --- a/tinygrad/mixin/movement.py +++ b/tinygrad/mixin/movement.py @@ -117,6 +117,7 @@ class MovementMixin: ``` """ axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args)) + assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}" if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}") flip_arg = tuple([i in axis_arg for i in range(len(self.shape))]) return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 3eda6f58b8..a665bca837 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -186,7 +186,7 @@ def shrink_multi(root:UOp, multi:UOp): def flip_multi(root:UOp, multi:UOp): assert multi.axis is None or not root.marg[multi.axis], "flipping not supported on sharded axis" - return multi.src[0].flip(root.marg).multi(multi.axis) + return multi.src[0].flip([i for i,x in enumerate(root.marg) if x]).multi(multi.axis) # from multiple devices -> one def copy_multi(multi:UOp, device:UOp):