diff --git a/test/test_remote.py b/test/test_remote.py index 47ba007247..796f98e62d 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -32,7 +32,6 @@ class TestRemoteMultiHost(unittest.TestCase): # Verify that everything is in one big cross-host graph assert len(do.captured._jit_cache) == 1 and isinstance(do.captured._jit_cache[0].prg, RemoteGraph), repr(do.captured) - @unittest.expectedFailure # multihost-aware schedule is in separate pr @Context(JIT_BATCH_SIZE=2**32) def test_multihost_aware_schedule(self): @TinyJit diff --git a/tinygrad/device.py b/tinygrad/device.py index 6f9f3999a6..ffe89386b9 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -278,9 +278,9 @@ class Compiler: class Compiled: profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device. - def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None): + def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None, group_id=None): self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph - self.renderer = renderer or Renderer() + self.renderer, self.group_id = renderer or Renderer(), group_id def synchronize(self): """ Synchronize all pending operations on the device. diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 9a64639bf2..a4e3ec4a09 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,8 +2,8 @@ from typing import cast from dataclasses import dataclass, field from collections import deque, defaultdict from tinygrad.uop.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers -from tinygrad.device import Buffer, MultiBuffer -from tinygrad.helpers import Metadata, unwrap, merge_dicts +from tinygrad.device import Device, Buffer, MultiBuffer +from tinygrad.helpers import Metadata, unwrap, all_same, merge_dicts # **** ScheduleItem return type @@ -61,11 +61,22 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[ raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}") # linearize KERNEL UOps into ScheduleItems in BFS order - queue = deque(k for k,v in in_degree.items() if v == 0) + + def _heuristic(k: UOp): + if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000 + return 0 + + last_heuristic: int = 0 + queues: defaultdict[int, deque[UOp]] = defaultdict(deque) + last_queue: deque[UOp] = deque() + for k,v in in_degree.items(): + if v == 0: queues[_heuristic(k)].append(k) + schedule: list[ScheduleItem] = [] var_vals: dict[Variable, int] = {} - while queue: - k = queue.popleft() + while last_queue or any(queues.values()): + if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic)) + k = last_queue.popleft() # unbind var_vals from the kernel local_var_vals: list[dict[Variable, int]] = [] ast = graph_rewrite(k.arg.ast, pm_unbind, ctx=local_var_vals, name="unbind vars") @@ -86,6 +97,6 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[ schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata)) for x in children[k]: in_degree[x] -= 1 - if in_degree[x] == 0: queue.append(x) + if in_degree[x] == 0: queues[_heuristic(x)].append(x) return schedule, var_vals diff --git a/tinygrad/kernelize/multi.py b/tinygrad/kernelize/multi.py index e0d065d68d..a0f8135891 100644 --- a/tinygrad/kernelize/multi.py +++ b/tinygrad/kernelize/multi.py @@ -2,8 +2,30 @@ from typing import cast import functools, itertools, operator from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve +from tinygrad.device import Device # *** allreduce implementation *** +def handle_allreduce_multirank(buf:UOp, red:UOp) -> UOp|None: + if not isinstance(buf.device, tuple): return None + + # Group buffers + groups: dict[int|None, list[UOp]] = {} + for i,dev in enumerate(buf.device): + groups.setdefault(Device[dev].group_id, []).append(buf.mselect(i)) + + # Skip if only one group or if every group has only one buffer + if len(groups) <= 1 or not any(len(g) > 1 for g in groups.values()): return None + + # Reduce inside each group + inner = [UOp(Ops.MSTACK, buf.dtype, tuple(bufs)).allreduce(red.arg, (cast(str, bufs[0].device),)).mselect(0) for bufs in groups.values()] + + # Allreduce across groups + outer = UOp(Ops.MSTACK, buf.dtype, tuple(inner)).allreduce(red.arg, tuple(buf.device for buf in inner)) + + # Broadcast back to all devices in the group + gid2bid = {Device[device].group_id: i for i,device in enumerate(outer.device)} + return outer.mselect(gid2bid[Device[red.device].group_id]).copy_to_device(red.device) if not isinstance(red.device, tuple) else \ + UOp(Ops.MSTACK, buf.dtype, tuple(outer.mselect(gid2bid[Device[device].group_id]).copy_to_device(device) for device in red.device)) def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: if not isinstance(buf.device, tuple): return None @@ -84,6 +106,7 @@ def mstack_early_shrink(view:UOp, ms:UOp): return ms.replace(src=tuple(ret)) replace_allreduce = PatternMatcher([ + (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank), (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce), # BROADCAST: explicitly expand broadcast copies and combine with MSTACK (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x: diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py index ad4b286e4e..cad43b532a 100644 --- a/tinygrad/runtime/ops_remote.py +++ b/tinygrad/runtime/ops_remote.py @@ -389,7 +389,7 @@ class RemoteDevice(Compiled): renderer_instance = renderer_class(*renderer[2]) renderer_instance.device = device graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None - super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph) + super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph, id(self.conn)) def finalize(self): with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)