From 84a2e2b8c10fa0d2e6d2e5cb1b02fa3bfa37f362 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Fri, 10 May 2024 17:15:12 +0300 Subject: [PATCH] hcq graph (#4380) * start hcq graph * hack-fix sync on amd * nv * fix nv * multigrah * fixes * temp fix for graph * this is not needed * fix * cleaner * linetr * fix none * faster cuda copy * faster amd copy * temp nv fixes * alloc on gpu * exp: faster amd * Revert "exp: faster amd" This reverts commit 2e4cfd1f7d8a33634c50fb5655cff1b40269d28c. * revert, unrelated * not in this pr * linter --- tinygrad/buffer.py | 1 + tinygrad/engine/jit.py | 8 +- tinygrad/runtime/graph/hcq.py | 139 ++++++++++++++++++++++++++++++++++ tinygrad/runtime/ops_amd.py | 6 +- tinygrad/runtime/ops_nv.py | 15 ++-- 5 files changed, 159 insertions(+), 10 deletions(-) create mode 100644 tinygrad/runtime/graph/hcq.py diff --git a/tinygrad/buffer.py b/tinygrad/buffer.py index 802c9b6974..fc9d739ed4 100644 --- a/tinygrad/buffer.py +++ b/tinygrad/buffer.py @@ -10,6 +10,7 @@ from tinygrad.dtype import DType, ImageDType class BufferOptions: image: Optional[ImageDType] = None uncached: bool = False + cpu_access: bool = False host: bool = False nolru: bool = False diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index c4c7eac2d0..471934a123 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -41,12 +41,14 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer] if ji.prg.__class__ in {EmptyOp, ViewOp}: continue ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None. if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device - elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}: + elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}: ji_graph_dev = Device[ji.bufs[0].device] + graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore can_be_graphed = ji_graph_dev and ji_graph_dev.graph - can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or - (isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiGraphRunner) and type(ji_graph_dev) == type(current_device))) #type:ignore + can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and + type(ji_graph_dev) == type(current_device)) + can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph if not can_extend_graph_batch and len(current_batch) > 0: flush_batch() if can_be_graphed: current_batch.append(ji) diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py new file mode 100644 index 0000000000..a113698c52 --- /dev/null +++ b/tinygrad/runtime/graph/hcq.py @@ -0,0 +1,139 @@ +import ctypes, collections, array, time +from typing import List, Any, Dict, cast, Optional, Tuple, Set +from tinygrad.helpers import GraphException, round_up, to_mv +from tinygrad.buffer import Buffer, BufferOptions +from tinygrad.device import Compiled, CompiledRunner, BufferXfer, Device +from tinygrad.shape.symbolic import Variable +from tinygrad.engine.realize import ExecItem +from tinygrad.engine.jit import MultiGraphRunner + +class HCQGraph(MultiGraphRunner): + def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): + super().__init__(jit_cache, input_rawbuffers, var_vals) + self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t + + # Check all jit items are compatible. + self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore + if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException + + # Allocate kernel args. + kernargs_size: Dict[Compiled, int] = collections.defaultdict(int) + for ji in self.jit_cache: + kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16) if isinstance(ji.prg, CompiledRunner) else 0 + kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()} + + # Fill initial arguments. + self.kargs_addrs: Dict[int, int] = {} + self.ji_kargs_structs: Dict[int, ctypes.Structure] = {} + for j,ji in enumerate(self.jit_cache): + if not isinstance(ji.prg, CompiledRunner): continue + self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device] + kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16) + + self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset) + for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr) + for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]]) + + # NV needs constbuffer to be set + if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0) + + # Build queues. + self.queue_list: List[Tuple[Any, ...]] = [] + + self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t) + self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices} + self.comp_signal_val = {dev: 0 for dev in self.devices} + + self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t) + self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices} + self.copy_signal_val = {dev: 0 for dev in self.devices} + + self.kickoff_signal = self.devices[0]._get_signal(value=0) + self.kickoff_value = 0 + self.graph_timeline = {dev: 0 for dev in self.devices} + + self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices} + + for j,ji in enumerate(self.jit_cache): + if isinstance(ji.prg, CompiledRunner): + deps = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1)) + deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device])) + self.comp_signal_val[ji.prg.device] = sig_val + + # Rebuilt runners with dynamic launch dims online. + if j in self.jc_idx_with_updatable_launch_dims: + if ji.prg.device in self.comp_queues: self.queue_list.append((self.comp_queues.pop(ji.prg.device), ji.prg.device)) + self.queue_list.append((j, deps)) + else: + for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val) + self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.launch_dims(var_vals)) \ + .signal(self.comp_signal[ji.prg.device], sig_val) + elif isinstance(ji.prg, BufferXfer): + dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] + Device[src.device]._gpu_map(dest._buf) #type: ignore + + deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1)) + deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]])) + self.copy_signal_val[Device[src.device]] = sig_val + + for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val) + self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \ + .signal(self.copy_signal[Device[src.device]], sig_val) + self.copy_to_devs[Device[dest.device]].add(Device[src.device]) + + for dev in self.devices: + if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev]) + for dep_dev in self.copy_to_devs: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev]) + + self.queue_list.append((self.comp_queues.pop(dev), dev)) + if self.copy_signal_val[dev] > 0: self.queue_list.append((self.copy_queues.pop(dev), dev)) + + def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: + # Wait and restore signals + self.kickoff_value += 1 + for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev]) + for dev in self.devices: + dev._set_signal(self.comp_signal[dev], 0) + dev._set_signal(self.copy_signal[dev], 0) + dev._set_signal(self.kickoff_signal, self.kickoff_value) + + # Update rawbuffers + for (j,i),input_idx in self.input_replace.items(): + self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr) + + # Update var_vals + for j in self.jc_idx_with_updatable_var_vals: + for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars): + self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v]) + + for dev in self.devices: + self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \ + .wait(self.kickoff_signal, self.kickoff_value).submit(dev) + self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \ + .wait(self.kickoff_signal, self.kickoff_value).submit(dev) + + for entry in self.queue_list: + if isinstance(entry[0], self.comp_hcq_t) or isinstance(entry[0], self.copy_hcq_t): queue, dev = entry + else: + # Kernel with dynamic launch bounds, rebuild it. + j, ji, deps, dev = entry[0], self.jit_cache[entry[0]], entry[1], self.jit_cache[entry[0]].prg.device + queue = self.comp_hcq_t() + for sig, val in deps: queue.wait(sig, val) + queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.launch_dims(var_vals)) \ + .signal(self.comp_signal[dev], value=j+1) + queue.submit(dev) + + for dev in self.devices: + self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev) + self.graph_timeline[dev] = dev.timeline_value + dev.timeline_value += 1 + + if wait: + st = time.perf_counter() + for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev]) + return time.perf_counter() - st + return None + + def access_resources(self, read, write, new_dependency): + deps = self._access_resources(read, write, new_dependency) + return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 77949e93b3..f533f72e0b 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -388,7 +388,7 @@ class AMDAllocator(LRUAllocator): def _alloc(self, size:int, options:BufferOptions): try: if options.host: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, public=True) - else: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=True) + else: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=options.cpu_access) except OSError as e: if e.errno == errno.ENOMEM: raise MemoryError("Cannot allocate memory") from e else: raise @@ -584,7 +584,9 @@ class AMDDevice(Compiled): self.pm4_write_pointer = to_mv(self.pm4_queue.write_pointer_address, 8).cast("Q") self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q") - super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self)) + from tinygrad.runtime.graph.hcq import HCQGraph + super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self), + functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue)) def synchronize(self): AMDDevice._wait_signal(self.timeline_signal, self.timeline_value - 1) diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index 0954d2f336..a4259cbe55 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -1,5 +1,5 @@ from __future__ import annotations -import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time +import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array from typing import Tuple, List, Any, cast from dataclasses import replace from tinygrad.device import Compiled, Compiler, CompilerOptions @@ -116,7 +116,7 @@ class HWComputeQueue: def submit(self, dev:NVDevice): if len(self.q) == 0: return assert len(self.q) < (1 << 21) - for i,packet in enumerate(self.q): dev.cmdq[dev.cmdq_wptr//4 + i] = packet + dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q) fifo_entry = dev.compute_put_value % dev.compute_gpfifo_entries dev.compute_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42) | (1 << 41) dev.compute_gpu_ring_controls.GPPut = (dev.compute_put_value + 1) % dev.compute_gpfifo_entries @@ -146,7 +146,7 @@ class HWCopyQueue: def submit(self, dev:NVDevice): if len(self.q) == 0: return - for i,packet in enumerate(self.q): dev.cmdq[dev.cmdq_wptr//4 + i] = packet + dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q) fifo_entry = dev.dma_put_value % dev.dma_gpfifo_entries dev.dma_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42) dev.dma_gpu_ring_controls.GPPut = (dev.dma_put_value + 1) % dev.dma_gpfifo_entries @@ -235,6 +235,9 @@ class NVProgram: def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False): if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch") + if not hasattr(self, "args_struct_t"): + self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] + + [(f'v{i}', ctypes.c_int) for i in range(len(vals))])) if self.device.kernargs_ptr >= (self.device.kernargs_page.base + self.device.kernargs_page.length - self.kernargs_segment_size): self.device.kernargs_ptr = self.device.kernargs_page.base @@ -265,7 +268,7 @@ class NVAllocator(LRUAllocator): def _alloc(self, size:int, options:BufferOptions): if options.host: return self.device._gpu_host_alloc(size) - else: return self.device._gpu_alloc(size) + else: return self.device._gpu_alloc(size, map_to_cpu=options.cpu_access) def _free(self, gpumem, options:BufferOptions): NVDevice.synchronize_system() @@ -491,7 +494,9 @@ class NVDevice(Compiled): self.arch: str = 'sm_89' # TODO: fix - super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self)) + from tinygrad.runtime.graph.hcq import HCQGraph + super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self), + functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue)) self._cmdq_setup_compute_gpfifo() self._cmdq_setup_dma_gpfifo()