more hcq typing [pr] (#7813)

* more hcq typing [pr]

* minor

* less generic
This commit is contained in:
George Hotz
2024-11-21 11:23:07 +08:00
committed by GitHub
parent 9df5a62c5e
commit 490a6130af
3 changed files with 44 additions and 39 deletions

View File

@@ -209,28 +209,30 @@ class AMDCopyQueue(HWQueue): # pylint: disable=abstract-method
if src is not None: self._patch(cmd_idx, offset=3+i*7, data=[*data64_le(src + SDMA_MAX_COPY_SIZE*i)])
if dest is not None: self._patch(cmd_idx, offset=5+i*7, data=[*data64_le(dest + SDMA_MAX_COPY_SIZE*i)])
def _signal(self, signal, value=0):
def _signal(self, signal:AMDSignal, value=0):
self._q([amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal._value_addr), value])
if signal._event_mailbox_ptr != 0:
self._q([amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal._event_mailbox_ptr), signal._event.event_id])
self._q([amd_gpu.SDMA_OP_TRAP, amd_gpu.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(signal._event.event_id)])
def _wait(self, signal, value=0):
def _wait(self, signal:AMDSignal, value=0):
self._q([amd_gpu.SDMA_OP_POLL_REGMEM | amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) | \
amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1), *data64_le(signal._value_addr), value, 0xffffffff,
amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff)])
def _update_signal(self, cmd_idx, signal=None, value=None): return self._update_wait(cmd_idx, signal, value) # the same offsets and commands
def _update_wait(self, cmd_idx, signal=None, value=None):
def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
return self._update_wait(cmd_idx, signal, value) # the same offsets and commands
def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
if signal is not None: self._patch(cmd_idx, offset=1, data=data64_le(signal._value_addr))
if value is not None: self._patch(cmd_idx, offset=3, data=[value])
def _timestamp(self, signal):
def _timestamp(self, signal:AMDSignal):
self._q([amd_gpu.SDMA_OP_TIMESTAMP | amd_gpu.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(amd_gpu.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL),
*data64_le(signal._timestamp_addr)])
def _submit(self, dev):
def _submit(self, dev:AMDDevice):
if dev.sdma_queue.put_value - dev.sdma_queue.read_ptr[0] > dev.sdma_queue.ring.nbytes: raise RuntimeError("SDMA queue overrun")
tail_blit_dword = 0
@@ -303,7 +305,7 @@ class AMDProgram(HCQProgram):
def __del__(self):
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferOptions(cpu_access=True, nolru=True))
class AMDAllocator(HCQAllocator):
class AMDAllocator(HCQAllocator['AMDDevice']):
def __init__(self, dev:AMDDevice): super().__init__(dev, batch_size=SDMA_MAX_COPY_SIZE)
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:

View File

@@ -83,7 +83,7 @@ class NVSignal(HCQSignal):
def _get_timestamp(self) -> decimal.Decimal: return decimal.Decimal(self._signal[1]) / decimal.Decimal(1000)
def _set_value(self, new_value:int): self._signal[0] = new_value
class NVCommandQueue(HWQueue): # pylint: disable=abstract-method
class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): # pylint: disable=abstract-method
def __del__(self):
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True))
@@ -132,7 +132,7 @@ class NVCommandQueue(HWQueue): # pylint: disable=abstract-method
dev.gpu_mmio[0x90 // 4] = gpfifo.token
gpfifo.put_value += 1
class NVComputeQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-method
class NVComputeQueue(NVCommandQueue): # pylint: disable=abstract-method
def __init__(self):
self.cmd_idx_to_qmd, self.cmd_idx_to_signal_id, self.cmd_idx_to_global_dims, self.cmd_idx_to_local_dims = {}, {}, {}, {}
super().__init__()
@@ -187,7 +187,7 @@ class NVComputeQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-meth
def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).compute_gpfifo)
class NVCopyQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-method
class NVCopyQueue(NVCommandQueue): # pylint: disable=abstract-method
def _copy(self, dest, src, copy_size):
self.q += [nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *data64(src), *data64(dest)]
self.q += [nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size]
@@ -290,7 +290,7 @@ class NVProgram(HCQProgram):
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
class NVAllocator(HCQAllocator):
class NVAllocator(HCQAllocator['NVDevice']):
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
if options.host: return self.dev._gpu_host_alloc(size, tag="user host memory")
return self.dev._gpu_alloc(size, map_to_cpu=options.cpu_access, huge_page=(size > (16 << 20)), tag=f"user memory ({options})")
@@ -310,7 +310,7 @@ class GPFifo:
put_value: int = 0
MAP_FIXED, MAP_NORESERVE = 0x10, 0x400
class NVDevice(HCQCompiled):
class NVDevice(HCQCompiled[NVSignal]):
root = None
fd_ctl: int = -1
fd_uvm: int = -1
@@ -534,9 +534,9 @@ class NVDevice(HCQCompiled):
NVComputeQueue().setup(compute_class=self.compute_class, local_mem_window=self.local_mem_window, shared_mem_window=self.shared_mem_window) \
.signal(self.timeline_signal, self.timeline_value).submit(self)
NVCopyQueue().wait(self.timeline_signal, self.timeline_value) \
.setup(copy_class=nv_gpu.AMPERE_DMA_COPY_B) \
.signal(self.timeline_signal, self.timeline_value + 1).submit(self)
cast(NVCopyQueue, NVCopyQueue().wait(self.timeline_signal, self.timeline_value)) \
.setup(copy_class=nv_gpu.AMPERE_DMA_COPY_B) \
.signal(self.timeline_signal, self.timeline_value + 1).submit(self)
self.timeline_value += 2
@@ -555,9 +555,9 @@ class NVDevice(HCQCompiled):
bytes_per_tpc = round_up(round_up(self.slm_per_thread * 32, 0x200) * self.max_warps_per_sm * self.num_sm_per_tpc, 0x8000)
self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * self.num_tpc_per_gpc * self.num_gpcs, 0x20000))
NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1) \
.setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \
.signal(self.timeline_signal, self.timeline_value).submit(self)
cast(NVComputeQueue, NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1)) \
.setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \
.signal(self.timeline_signal, self.timeline_value).submit(self)
self.timeline_value += 1
def invalidate_caches(self):

View File

@@ -1,13 +1,20 @@
from __future__ import annotations
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, Any
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, ParamSpec, Concatenate
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes, functools
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv
from tinygrad.renderer import Renderer
from tinygrad.device import BufferOptions, Compiler, Compiled, LRUAllocator
# **************** for HCQ Compatible Devices ****************
def hcq_command(func: Callable[..., None]) -> Callable[..., Any]:
SignalType = TypeVar('SignalType', bound='HCQSignal')
DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
ProgramType = TypeVar('ProgramType', bound='HCQProgram')
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
QueueType = TypeVar('QueueType', bound='HWQueue')
P = ParamSpec('P')
def hcq_command(func: Callable[Concatenate[QueueType, P], None]) -> Callable[Concatenate[QueueType, P], QueueType]:
"""
Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
@@ -17,7 +24,8 @@ def hcq_command(func: Callable[..., None]) -> Callable[..., Any]:
def command_method(self, ...): ...
```
"""
def __wrapper(self:HWQueue, *args, **kwargs):
@functools.wraps(func)
def __wrapper(self:QueueType, *args:P.args, **kwargs:P.kwargs) -> QueueType:
self.cmds_offset.append(len(self.q))
func(self, *args, **kwargs)
self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
@@ -25,11 +33,6 @@ def hcq_command(func: Callable[..., None]) -> Callable[..., Any]:
return self
return __wrapper
SignalType = TypeVar('SignalType', bound='HCQSignal')
DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
ProgramType = TypeVar('ProgramType', bound='HCQProgram')
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
"""
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
@@ -178,7 +181,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
# *** commands for copy queues ***
@hcq_command
def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
def copy(self, dest:int, src:int, copy_size:int):
"""
Enqueues a copy command to transfer data. Only on copy queues.
@@ -188,9 +191,9 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
copy_size: The size of data to copy
"""
self._copy(dest, src, copy_size)
def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
def _copy(self, dest:int, src:int, copy_size:int): raise NotImplementedError("backend should overload this function")
def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
def update_copy(self, cmd_idx:int, dest:Optional[int]=None, src:Optional[int]=None):
"""
Updates a previously queued copy command. Only on copy queues.
@@ -202,7 +205,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
self._update_copy(cmd_idx, dest, src)
return self
def _update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer], src:Optional[HCQBuffer]):
def _update_copy(self, cmd_idx:int, dest:Optional[int], src:Optional[int]):
raise NotImplementedError("backend should overload this function")
class HCQSignal:
@@ -348,7 +351,7 @@ class ProfileLogger:
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
class HCQCompiled(Compiled):
class HCQCompiled(Compiled, Generic[SignalType]):
"""
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
"""
@@ -356,7 +359,7 @@ class HCQCompiled(Compiled):
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]):
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
self.timeline_value:int = 1
@@ -494,8 +497,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac
self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
self.b_timeline[self.b_next] = self.dev.timeline_value
self.dev.timeline_value += 1
@@ -511,8 +514,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"DISK -> {self.dev.device}", enabled=PROFILE):
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
self.b_timeline[batch_info[1]] = self.dev.timeline_value
self.dev.timeline_value += 1
@@ -523,8 +526,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE):
for i in range(0, dest.nbytes, self.b[0].size):
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
self.dev.timeline_signal.wait(self.dev.timeline_value)
self.dev.timeline_value += 1