* start hcq graph

* hack-fix sync on amd

* nv

* fix nv

* multigrah

* fixes

* temp fix for graph

* this is not needed

* fix

* cleaner

* linetr

* fix none

* faster cuda copy

* faster amd copy

* temp nv fixes

* alloc on gpu

* exp: faster amd

* Revert "exp: faster amd"

This reverts commit 2e4cfd1f7d8a33634c50fb5655cff1b40269d28c.

* revert, unrelated

* not in this pr

* linter
This commit is contained in:
nimlgen
2024-05-10 17:15:12 +03:00
committed by GitHub
parent 2b7ab60584
commit 84a2e2b8c1
5 changed files with 159 additions and 10 deletions

View File

@@ -10,6 +10,7 @@ from tinygrad.dtype import DType, ImageDType
class BufferOptions:
image: Optional[ImageDType] = None
uncached: bool = False
cpu_access: bool = False
host: bool = False
nolru: bool = False

View File

@@ -41,12 +41,14 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}:
ji_graph_dev = Device[ji.bufs[0].device]
graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
can_be_graphed = ji_graph_dev and ji_graph_dev.graph
can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or
(isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiGraphRunner) and type(ji_graph_dev) == type(current_device))) #type:ignore
can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
type(ji_graph_dev) == type(current_device))
can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
if can_be_graphed: current_batch.append(ji)

View File

@@ -0,0 +1,139 @@
import ctypes, collections, array, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import GraphException, round_up, to_mv
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledRunner, BufferXfer, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t
# Check all jit items are compatible.
self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore
if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException
# Allocate kernel args.
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in self.jit_cache:
kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16) if isinstance(ji.prg, CompiledRunner) else 0
kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()}
# Fill initial arguments.
self.kargs_addrs: Dict[int, int] = {}
self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
for j,ji in enumerate(self.jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr)
for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]])
# NV needs constbuffer to be set
if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
# Build queues.
self.queue_list: List[Tuple[Any, ...]] = []
self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
self.comp_signal_val = {dev: 0 for dev in self.devices}
self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
self.copy_signal_val = {dev: 0 for dev in self.devices}
self.kickoff_signal = self.devices[0]._get_signal(value=0)
self.kickoff_value = 0
self.graph_timeline = {dev: 0 for dev in self.devices}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledRunner):
deps = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
self.comp_signal_val[ji.prg.device] = sig_val
# Rebuilt runners with dynamic launch dims online.
if j in self.jc_idx_with_updatable_launch_dims:
if ji.prg.device in self.comp_queues: self.queue_list.append((self.comp_queues.pop(ji.prg.device), ji.prg.device))
self.queue_list.append((j, deps))
else:
for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.launch_dims(var_vals)) \
.signal(self.comp_signal[ji.prg.device], sig_val)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
Device[src.device]._gpu_map(dest._buf) #type: ignore
deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]]))
self.copy_signal_val[Device[src.device]] = sig_val
for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
.signal(self.copy_signal[Device[src.device]], sig_val)
self.copy_to_devs[Device[dest.device]].add(Device[src.device])
for dev in self.devices:
if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
for dep_dev in self.copy_to_devs: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
self.queue_list.append((self.comp_queues.pop(dev), dev))
if self.copy_signal_val[dev] > 0: self.queue_list.append((self.copy_queues.pop(dev), dev))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
for dev in self.devices:
dev._set_signal(self.comp_signal[dev], 0)
dev._set_signal(self.copy_signal[dev], 0)
dev._set_signal(self.kickoff_signal, self.kickoff_value)
# Update rawbuffers
for (j,i),input_idx in self.input_replace.items():
self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)
# Update var_vals
for j in self.jc_idx_with_updatable_var_vals:
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
for dev in self.devices:
self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
for entry in self.queue_list:
if isinstance(entry[0], self.comp_hcq_t) or isinstance(entry[0], self.copy_hcq_t): queue, dev = entry
else:
# Kernel with dynamic launch bounds, rebuild it.
j, ji, deps, dev = entry[0], self.jit_cache[entry[0]], entry[1], self.jit_cache[entry[0]].prg.device
queue = self.comp_hcq_t()
for sig, val in deps: queue.wait(sig, val)
queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.launch_dims(var_vals)) \
.signal(self.comp_signal[dev], value=j+1)
queue.submit(dev)
for dev in self.devices:
self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
self.graph_timeline[dev] = dev.timeline_value
dev.timeline_value += 1
if wait:
st = time.perf_counter()
for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
return time.perf_counter() - st
return None
def access_resources(self, read, write, new_dependency):
deps = self._access_resources(read, write, new_dependency)
return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]

View File

@@ -388,7 +388,7 @@ class AMDAllocator(LRUAllocator):
def _alloc(self, size:int, options:BufferOptions):
try:
if options.host: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, public=True)
else: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=True)
else: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=options.cpu_access)
except OSError as e:
if e.errno == errno.ENOMEM: raise MemoryError("Cannot allocate memory") from e
else: raise
@@ -584,7 +584,9 @@ class AMDDevice(Compiled):
self.pm4_write_pointer = to_mv(self.pm4_queue.write_pointer_address, 8).cast("Q")
self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")
super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self))
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))
def synchronize(self):
AMDDevice._wait_signal(self.timeline_signal, self.timeline_value - 1)

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time
import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
from typing import Tuple, List, Any, cast
from dataclasses import replace
from tinygrad.device import Compiled, Compiler, CompilerOptions
@@ -116,7 +116,7 @@ class HWComputeQueue:
def submit(self, dev:NVDevice):
if len(self.q) == 0: return
assert len(self.q) < (1 << 21)
for i,packet in enumerate(self.q): dev.cmdq[dev.cmdq_wptr//4 + i] = packet
dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
fifo_entry = dev.compute_put_value % dev.compute_gpfifo_entries
dev.compute_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42) | (1 << 41)
dev.compute_gpu_ring_controls.GPPut = (dev.compute_put_value + 1) % dev.compute_gpfifo_entries
@@ -146,7 +146,7 @@ class HWCopyQueue:
def submit(self, dev:NVDevice):
if len(self.q) == 0: return
for i,packet in enumerate(self.q): dev.cmdq[dev.cmdq_wptr//4 + i] = packet
dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
fifo_entry = dev.dma_put_value % dev.dma_gpfifo_entries
dev.dma_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42)
dev.dma_gpu_ring_controls.GPPut = (dev.dma_put_value + 1) % dev.dma_gpfifo_entries
@@ -235,6 +235,9 @@ class NVProgram:
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch")
if not hasattr(self, "args_struct_t"):
self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
[(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
if self.device.kernargs_ptr >= (self.device.kernargs_page.base + self.device.kernargs_page.length - self.kernargs_segment_size):
self.device.kernargs_ptr = self.device.kernargs_page.base
@@ -265,7 +268,7 @@ class NVAllocator(LRUAllocator):
def _alloc(self, size:int, options:BufferOptions):
if options.host: return self.device._gpu_host_alloc(size)
else: return self.device._gpu_alloc(size)
else: return self.device._gpu_alloc(size, map_to_cpu=options.cpu_access)
def _free(self, gpumem, options:BufferOptions):
NVDevice.synchronize_system()
@@ -491,7 +494,9 @@ class NVDevice(Compiled):
self.arch: str = 'sm_89' # TODO: fix
super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self))
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self),
functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))
self._cmdq_setup_compute_gpfifo()
self._cmdq_setup_dma_gpfifo()