From 490a6130afc262dd60e656a2f7eec8e58ec76528 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:23:07 +0800 Subject: [PATCH] more hcq typing [pr] (#7813) * more hcq typing [pr] * minor * less generic --- tinygrad/runtime/ops_amd.py | 16 +++++++----- tinygrad/runtime/ops_nv.py | 22 ++++++++-------- tinygrad/runtime/support/hcq.py | 45 ++++++++++++++++++--------------- 3 files changed, 44 insertions(+), 39 deletions(-) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 4d2297a2c0..3e7df8a73d 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -209,28 +209,30 @@ class AMDCopyQueue(HWQueue): # pylint: disable=abstract-method if src is not None: self._patch(cmd_idx, offset=3+i*7, data=[*data64_le(src + SDMA_MAX_COPY_SIZE*i)]) if dest is not None: self._patch(cmd_idx, offset=5+i*7, data=[*data64_le(dest + SDMA_MAX_COPY_SIZE*i)]) - def _signal(self, signal, value=0): + def _signal(self, signal:AMDSignal, value=0): self._q([amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal._value_addr), value]) if signal._event_mailbox_ptr != 0: self._q([amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal._event_mailbox_ptr), signal._event.event_id]) self._q([amd_gpu.SDMA_OP_TRAP, amd_gpu.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(signal._event.event_id)]) - def _wait(self, signal, value=0): + def _wait(self, signal:AMDSignal, value=0): self._q([amd_gpu.SDMA_OP_POLL_REGMEM | amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) | \ amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1), *data64_le(signal._value_addr), value, 0xffffffff, amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff)]) - def _update_signal(self, cmd_idx, signal=None, value=None): return self._update_wait(cmd_idx, signal, value) # the same offsets and commands - def _update_wait(self, cmd_idx, signal=None, value=None): + def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None): + return self._update_wait(cmd_idx, signal, value) # the same offsets and commands + + def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None): if signal is not None: self._patch(cmd_idx, offset=1, data=data64_le(signal._value_addr)) if value is not None: self._patch(cmd_idx, offset=3, data=[value]) - def _timestamp(self, signal): + def _timestamp(self, signal:AMDSignal): self._q([amd_gpu.SDMA_OP_TIMESTAMP | amd_gpu.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(amd_gpu.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL), *data64_le(signal._timestamp_addr)]) - def _submit(self, dev): + def _submit(self, dev:AMDDevice): if dev.sdma_queue.put_value - dev.sdma_queue.read_ptr[0] > dev.sdma_queue.ring.nbytes: raise RuntimeError("SDMA queue overrun") tail_blit_dword = 0 @@ -303,7 +305,7 @@ class AMDProgram(HCQProgram): def __del__(self): if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferOptions(cpu_access=True, nolru=True)) -class AMDAllocator(HCQAllocator): +class AMDAllocator(HCQAllocator['AMDDevice']): def __init__(self, dev:AMDDevice): super().__init__(dev, batch_size=SDMA_MAX_COPY_SIZE) def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index 3037f5e4e3..48688ec09f 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -83,7 +83,7 @@ class NVSignal(HCQSignal): def _get_timestamp(self) -> decimal.Decimal: return decimal.Decimal(self._signal[1]) / decimal.Decimal(1000) def _set_value(self, new_value:int): self._signal[0] = new_value -class NVCommandQueue(HWQueue): # pylint: disable=abstract-method +class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): # pylint: disable=abstract-method def __del__(self): if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True)) @@ -132,7 +132,7 @@ class NVCommandQueue(HWQueue): # pylint: disable=abstract-method dev.gpu_mmio[0x90 // 4] = gpfifo.token gpfifo.put_value += 1 -class NVComputeQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-method +class NVComputeQueue(NVCommandQueue): # pylint: disable=abstract-method def __init__(self): self.cmd_idx_to_qmd, self.cmd_idx_to_signal_id, self.cmd_idx_to_global_dims, self.cmd_idx_to_local_dims = {}, {}, {}, {} super().__init__() @@ -187,7 +187,7 @@ class NVComputeQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-meth def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).compute_gpfifo) -class NVCopyQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-method +class NVCopyQueue(NVCommandQueue): # pylint: disable=abstract-method def _copy(self, dest, src, copy_size): self.q += [nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *data64(src), *data64(dest)] self.q += [nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size] @@ -290,7 +290,7 @@ class NVProgram(HCQProgram): raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}") return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait) -class NVAllocator(HCQAllocator): +class NVAllocator(HCQAllocator['NVDevice']): def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: if options.host: return self.dev._gpu_host_alloc(size, tag="user host memory") return self.dev._gpu_alloc(size, map_to_cpu=options.cpu_access, huge_page=(size > (16 << 20)), tag=f"user memory ({options})") @@ -310,7 +310,7 @@ class GPFifo: put_value: int = 0 MAP_FIXED, MAP_NORESERVE = 0x10, 0x400 -class NVDevice(HCQCompiled): +class NVDevice(HCQCompiled[NVSignal]): root = None fd_ctl: int = -1 fd_uvm: int = -1 @@ -534,9 +534,9 @@ class NVDevice(HCQCompiled): NVComputeQueue().setup(compute_class=self.compute_class, local_mem_window=self.local_mem_window, shared_mem_window=self.shared_mem_window) \ .signal(self.timeline_signal, self.timeline_value).submit(self) - NVCopyQueue().wait(self.timeline_signal, self.timeline_value) \ - .setup(copy_class=nv_gpu.AMPERE_DMA_COPY_B) \ - .signal(self.timeline_signal, self.timeline_value + 1).submit(self) + cast(NVCopyQueue, NVCopyQueue().wait(self.timeline_signal, self.timeline_value)) \ + .setup(copy_class=nv_gpu.AMPERE_DMA_COPY_B) \ + .signal(self.timeline_signal, self.timeline_value + 1).submit(self) self.timeline_value += 2 @@ -555,9 +555,9 @@ class NVDevice(HCQCompiled): bytes_per_tpc = round_up(round_up(self.slm_per_thread * 32, 0x200) * self.max_warps_per_sm * self.num_sm_per_tpc, 0x8000) self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * self.num_tpc_per_gpc * self.num_gpcs, 0x20000)) - NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1) \ - .setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \ - .signal(self.timeline_signal, self.timeline_value).submit(self) + cast(NVComputeQueue, NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1)) \ + .setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \ + .signal(self.timeline_signal, self.timeline_value).submit(self) self.timeline_value += 1 def invalidate_caches(self): diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index 5514849115..3a18c18988 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -1,13 +1,20 @@ from __future__ import annotations -from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, Any -import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes +from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, ParamSpec, Concatenate +import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes, functools from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv from tinygrad.renderer import Renderer from tinygrad.device import BufferOptions, Compiler, Compiled, LRUAllocator # **************** for HCQ Compatible Devices **************** -def hcq_command(func: Callable[..., None]) -> Callable[..., Any]: +SignalType = TypeVar('SignalType', bound='HCQSignal') +DeviceType = TypeVar('DeviceType', bound='HCQCompiled') +ProgramType = TypeVar('ProgramType', bound='HCQProgram') +ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState') +QueueType = TypeVar('QueueType', bound='HWQueue') + +P = ParamSpec('P') +def hcq_command(func: Callable[Concatenate[QueueType, P], None]) -> Callable[Concatenate[QueueType, P], QueueType]: """ Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates. @@ -17,7 +24,8 @@ def hcq_command(func: Callable[..., None]) -> Callable[..., Any]: def command_method(self, ...): ... ``` """ - def __wrapper(self:HWQueue, *args, **kwargs): + @functools.wraps(func) + def __wrapper(self:QueueType, *args:P.args, **kwargs:P.kwargs) -> QueueType: self.cmds_offset.append(len(self.q)) func(self, *args, **kwargs) self.cmds_len.append(len(self.q) - self.cmds_offset[-1]) @@ -25,11 +33,6 @@ def hcq_command(func: Callable[..., None]) -> Callable[..., Any]: return self return __wrapper -SignalType = TypeVar('SignalType', bound='HCQSignal') -DeviceType = TypeVar('DeviceType', bound='HCQCompiled') -ProgramType = TypeVar('ProgramType', bound='HCQProgram') -ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState') - class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]): """ A base class for hardware command queues in the HCQ (Hardware Command Queue) API. @@ -178,7 +181,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]): # *** commands for copy queues *** @hcq_command - def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): + def copy(self, dest:int, src:int, copy_size:int): """ Enqueues a copy command to transfer data. Only on copy queues. @@ -188,9 +191,9 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]): copy_size: The size of data to copy """ self._copy(dest, src, copy_size) - def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function") + def _copy(self, dest:int, src:int, copy_size:int): raise NotImplementedError("backend should overload this function") - def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None): + def update_copy(self, cmd_idx:int, dest:Optional[int]=None, src:Optional[int]=None): """ Updates a previously queued copy command. Only on copy queues. @@ -202,7 +205,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]): if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command") self._update_copy(cmd_idx, dest, src) return self - def _update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer], src:Optional[HCQBuffer]): + def _update_copy(self, cmd_idx:int, dest:Optional[int], src:Optional[int]): raise NotImplementedError("backend should overload this function") class HCQSignal: @@ -348,7 +351,7 @@ class ProfileLogger: with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson})) print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.") -class HCQCompiled(Compiled): +class HCQCompiled(Compiled, Generic[SignalType]): """ A base class for devices compatible with the HCQ (Hardware Command Queue) API. """ @@ -356,7 +359,7 @@ class HCQCompiled(Compiled): gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan') gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan') - def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal], + def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType], comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]): self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t self.timeline_value:int = 1 @@ -494,8 +497,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac self.dev.timeline_signal.wait(self.b_timeline[self.b_next]) ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i)) self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \ - .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \ - .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev) + .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \ + .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev) self.b_timeline[self.b_next] = self.dev.timeline_value self.dev.timeline_value += 1 @@ -511,8 +514,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"DISK -> {self.dev.device}", enabled=PROFILE): for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size): self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \ - .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \ - .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev) + .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \ + .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev) self.b_timeline[batch_info[1]] = self.dev.timeline_value self.dev.timeline_value += 1 @@ -523,8 +526,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE): for i in range(0, dest.nbytes, self.b[0].size): self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \ - .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \ - .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev) + .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \ + .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev) self.dev.timeline_signal.wait(self.dev.timeline_value) self.dev.timeline_value += 1