mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
more hcq typing [pr] (#7813)
* more hcq typing [pr] * minor * less generic
This commit is contained in:
@@ -209,28 +209,30 @@ class AMDCopyQueue(HWQueue): # pylint: disable=abstract-method
|
||||
if src is not None: self._patch(cmd_idx, offset=3+i*7, data=[*data64_le(src + SDMA_MAX_COPY_SIZE*i)])
|
||||
if dest is not None: self._patch(cmd_idx, offset=5+i*7, data=[*data64_le(dest + SDMA_MAX_COPY_SIZE*i)])
|
||||
|
||||
def _signal(self, signal, value=0):
|
||||
def _signal(self, signal:AMDSignal, value=0):
|
||||
self._q([amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal._value_addr), value])
|
||||
|
||||
if signal._event_mailbox_ptr != 0:
|
||||
self._q([amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal._event_mailbox_ptr), signal._event.event_id])
|
||||
self._q([amd_gpu.SDMA_OP_TRAP, amd_gpu.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(signal._event.event_id)])
|
||||
|
||||
def _wait(self, signal, value=0):
|
||||
def _wait(self, signal:AMDSignal, value=0):
|
||||
self._q([amd_gpu.SDMA_OP_POLL_REGMEM | amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) | \
|
||||
amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1), *data64_le(signal._value_addr), value, 0xffffffff,
|
||||
amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff)])
|
||||
|
||||
def _update_signal(self, cmd_idx, signal=None, value=None): return self._update_wait(cmd_idx, signal, value) # the same offsets and commands
|
||||
def _update_wait(self, cmd_idx, signal=None, value=None):
|
||||
def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
|
||||
return self._update_wait(cmd_idx, signal, value) # the same offsets and commands
|
||||
|
||||
def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
|
||||
if signal is not None: self._patch(cmd_idx, offset=1, data=data64_le(signal._value_addr))
|
||||
if value is not None: self._patch(cmd_idx, offset=3, data=[value])
|
||||
|
||||
def _timestamp(self, signal):
|
||||
def _timestamp(self, signal:AMDSignal):
|
||||
self._q([amd_gpu.SDMA_OP_TIMESTAMP | amd_gpu.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(amd_gpu.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL),
|
||||
*data64_le(signal._timestamp_addr)])
|
||||
|
||||
def _submit(self, dev):
|
||||
def _submit(self, dev:AMDDevice):
|
||||
if dev.sdma_queue.put_value - dev.sdma_queue.read_ptr[0] > dev.sdma_queue.ring.nbytes: raise RuntimeError("SDMA queue overrun")
|
||||
|
||||
tail_blit_dword = 0
|
||||
@@ -303,7 +305,7 @@ class AMDProgram(HCQProgram):
|
||||
def __del__(self):
|
||||
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferOptions(cpu_access=True, nolru=True))
|
||||
|
||||
class AMDAllocator(HCQAllocator):
|
||||
class AMDAllocator(HCQAllocator['AMDDevice']):
|
||||
def __init__(self, dev:AMDDevice): super().__init__(dev, batch_size=SDMA_MAX_COPY_SIZE)
|
||||
|
||||
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
|
||||
|
||||
@@ -83,7 +83,7 @@ class NVSignal(HCQSignal):
|
||||
def _get_timestamp(self) -> decimal.Decimal: return decimal.Decimal(self._signal[1]) / decimal.Decimal(1000)
|
||||
def _set_value(self, new_value:int): self._signal[0] = new_value
|
||||
|
||||
class NVCommandQueue(HWQueue): # pylint: disable=abstract-method
|
||||
class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): # pylint: disable=abstract-method
|
||||
def __del__(self):
|
||||
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferOptions(cpu_access=True, nolru=True))
|
||||
|
||||
@@ -132,7 +132,7 @@ class NVCommandQueue(HWQueue): # pylint: disable=abstract-method
|
||||
dev.gpu_mmio[0x90 // 4] = gpfifo.token
|
||||
gpfifo.put_value += 1
|
||||
|
||||
class NVComputeQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-method
|
||||
class NVComputeQueue(NVCommandQueue): # pylint: disable=abstract-method
|
||||
def __init__(self):
|
||||
self.cmd_idx_to_qmd, self.cmd_idx_to_signal_id, self.cmd_idx_to_global_dims, self.cmd_idx_to_local_dims = {}, {}, {}, {}
|
||||
super().__init__()
|
||||
@@ -187,7 +187,7 @@ class NVComputeQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-meth
|
||||
|
||||
def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).compute_gpfifo)
|
||||
|
||||
class NVCopyQueue(NVCommandQueue, HWQueue): # pylint: disable=abstract-method
|
||||
class NVCopyQueue(NVCommandQueue): # pylint: disable=abstract-method
|
||||
def _copy(self, dest, src, copy_size):
|
||||
self.q += [nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *data64(src), *data64(dest)]
|
||||
self.q += [nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size]
|
||||
@@ -290,7 +290,7 @@ class NVProgram(HCQProgram):
|
||||
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
|
||||
return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
|
||||
|
||||
class NVAllocator(HCQAllocator):
|
||||
class NVAllocator(HCQAllocator['NVDevice']):
|
||||
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
|
||||
if options.host: return self.dev._gpu_host_alloc(size, tag="user host memory")
|
||||
return self.dev._gpu_alloc(size, map_to_cpu=options.cpu_access, huge_page=(size > (16 << 20)), tag=f"user memory ({options})")
|
||||
@@ -310,7 +310,7 @@ class GPFifo:
|
||||
put_value: int = 0
|
||||
|
||||
MAP_FIXED, MAP_NORESERVE = 0x10, 0x400
|
||||
class NVDevice(HCQCompiled):
|
||||
class NVDevice(HCQCompiled[NVSignal]):
|
||||
root = None
|
||||
fd_ctl: int = -1
|
||||
fd_uvm: int = -1
|
||||
@@ -534,9 +534,9 @@ class NVDevice(HCQCompiled):
|
||||
NVComputeQueue().setup(compute_class=self.compute_class, local_mem_window=self.local_mem_window, shared_mem_window=self.shared_mem_window) \
|
||||
.signal(self.timeline_signal, self.timeline_value).submit(self)
|
||||
|
||||
NVCopyQueue().wait(self.timeline_signal, self.timeline_value) \
|
||||
.setup(copy_class=nv_gpu.AMPERE_DMA_COPY_B) \
|
||||
.signal(self.timeline_signal, self.timeline_value + 1).submit(self)
|
||||
cast(NVCopyQueue, NVCopyQueue().wait(self.timeline_signal, self.timeline_value)) \
|
||||
.setup(copy_class=nv_gpu.AMPERE_DMA_COPY_B) \
|
||||
.signal(self.timeline_signal, self.timeline_value + 1).submit(self)
|
||||
|
||||
self.timeline_value += 2
|
||||
|
||||
@@ -555,9 +555,9 @@ class NVDevice(HCQCompiled):
|
||||
bytes_per_tpc = round_up(round_up(self.slm_per_thread * 32, 0x200) * self.max_warps_per_sm * self.num_sm_per_tpc, 0x8000)
|
||||
self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * self.num_tpc_per_gpc * self.num_gpcs, 0x20000))
|
||||
|
||||
NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1) \
|
||||
.setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \
|
||||
.signal(self.timeline_signal, self.timeline_value).submit(self)
|
||||
cast(NVComputeQueue, NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1)) \
|
||||
.setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \
|
||||
.signal(self.timeline_signal, self.timeline_value).submit(self)
|
||||
self.timeline_value += 1
|
||||
|
||||
def invalidate_caches(self):
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, Any
|
||||
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes
|
||||
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, ParamSpec, Concatenate
|
||||
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes, functools
|
||||
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.device import BufferOptions, Compiler, Compiled, LRUAllocator
|
||||
|
||||
# **************** for HCQ Compatible Devices ****************
|
||||
|
||||
def hcq_command(func: Callable[..., None]) -> Callable[..., Any]:
|
||||
SignalType = TypeVar('SignalType', bound='HCQSignal')
|
||||
DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
|
||||
ProgramType = TypeVar('ProgramType', bound='HCQProgram')
|
||||
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
|
||||
QueueType = TypeVar('QueueType', bound='HWQueue')
|
||||
|
||||
P = ParamSpec('P')
|
||||
def hcq_command(func: Callable[Concatenate[QueueType, P], None]) -> Callable[Concatenate[QueueType, P], QueueType]:
|
||||
"""
|
||||
Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
|
||||
|
||||
@@ -17,7 +24,8 @@ def hcq_command(func: Callable[..., None]) -> Callable[..., Any]:
|
||||
def command_method(self, ...): ...
|
||||
```
|
||||
"""
|
||||
def __wrapper(self:HWQueue, *args, **kwargs):
|
||||
@functools.wraps(func)
|
||||
def __wrapper(self:QueueType, *args:P.args, **kwargs:P.kwargs) -> QueueType:
|
||||
self.cmds_offset.append(len(self.q))
|
||||
func(self, *args, **kwargs)
|
||||
self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
|
||||
@@ -25,11 +33,6 @@ def hcq_command(func: Callable[..., None]) -> Callable[..., Any]:
|
||||
return self
|
||||
return __wrapper
|
||||
|
||||
SignalType = TypeVar('SignalType', bound='HCQSignal')
|
||||
DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
|
||||
ProgramType = TypeVar('ProgramType', bound='HCQProgram')
|
||||
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
|
||||
|
||||
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
||||
"""
|
||||
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
|
||||
@@ -178,7 +181,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
||||
# *** commands for copy queues ***
|
||||
|
||||
@hcq_command
|
||||
def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
|
||||
def copy(self, dest:int, src:int, copy_size:int):
|
||||
"""
|
||||
Enqueues a copy command to transfer data. Only on copy queues.
|
||||
|
||||
@@ -188,9 +191,9 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
||||
copy_size: The size of data to copy
|
||||
"""
|
||||
self._copy(dest, src, copy_size)
|
||||
def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
|
||||
def _copy(self, dest:int, src:int, copy_size:int): raise NotImplementedError("backend should overload this function")
|
||||
|
||||
def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
|
||||
def update_copy(self, cmd_idx:int, dest:Optional[int]=None, src:Optional[int]=None):
|
||||
"""
|
||||
Updates a previously queued copy command. Only on copy queues.
|
||||
|
||||
@@ -202,7 +205,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
||||
if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
|
||||
self._update_copy(cmd_idx, dest, src)
|
||||
return self
|
||||
def _update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer], src:Optional[HCQBuffer]):
|
||||
def _update_copy(self, cmd_idx:int, dest:Optional[int], src:Optional[int]):
|
||||
raise NotImplementedError("backend should overload this function")
|
||||
|
||||
class HCQSignal:
|
||||
@@ -348,7 +351,7 @@ class ProfileLogger:
|
||||
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
|
||||
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
||||
|
||||
class HCQCompiled(Compiled):
|
||||
class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
"""
|
||||
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
|
||||
"""
|
||||
@@ -356,7 +359,7 @@ class HCQCompiled(Compiled):
|
||||
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
|
||||
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
|
||||
|
||||
def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
|
||||
def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
|
||||
comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]):
|
||||
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
|
||||
self.timeline_value:int = 1
|
||||
@@ -494,8 +497,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac
|
||||
self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
|
||||
ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
|
||||
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
||||
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
|
||||
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
||||
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
|
||||
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
||||
self.b_timeline[self.b_next] = self.dev.timeline_value
|
||||
self.dev.timeline_value += 1
|
||||
|
||||
@@ -511,8 +514,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac
|
||||
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"DISK -> {self.dev.device}", enabled=PROFILE):
|
||||
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
|
||||
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
||||
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
|
||||
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
||||
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
|
||||
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
||||
self.b_timeline[batch_info[1]] = self.dev.timeline_value
|
||||
self.dev.timeline_value += 1
|
||||
|
||||
@@ -523,8 +526,8 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]): # pylint: disable=abstrac
|
||||
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE):
|
||||
for i in range(0, dest.nbytes, self.b[0].size):
|
||||
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
||||
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
|
||||
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
||||
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
|
||||
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
|
||||
self.dev.timeline_signal.wait(self.dev.timeline_value)
|
||||
self.dev.timeline_value += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user