diff --git a/docs/runtime.md b/docs/runtime.md index ac2d0e7d60..257afab1a4 100644 --- a/docs/runtime.md +++ b/docs/runtime.md @@ -5,7 +5,7 @@ tinygrad supports various runtimes, enabling your code to scale across a wide ra | Runtime | Description | Compiler Options | Requirements | |---------|-------------|------------------|--------------| | [NV](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_nv.py) | Provides acceleration for NVIDIA GPUs | nvrtc (default)
PTX (`DEV=NV:PTX`) | Ampere/Ada/Blackwell series GPUs.
You can select an interface via [the `DEV` variable](env_vars.md#dev-variable). See [NV interfaces](#nv-interfaces) for details. | -| [AMD](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_amd.py) | Provides acceleration for AMD GPUs | LLVM (`DEV=AMD:LLVM`)
HIP/COMGR (`DEV=AMD:HIP`) | RDNA2 or newer GPUs.
You can select an interface via [the `DEV` variable](env_vars.md#dev-variable). See [AMD interfaces](#amd-interfaces) for details. | +| [AMD](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_amd.py) | Provides acceleration for AMD GPUs | LLVM (`DEV=AMD:LLVM`)
HIP/COMGR (`DEV=AMD:HIP`) | CDNA3, CDNA4, RDNA3 or RDNA4 GPUs.
You can select an interface via [the `DEV` variable](env_vars.md#dev-variable). See [AMD interfaces](#amd-interfaces) for details. | | [QCOM](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_qcom.py) | Provides acceleration for QCOM GPUs | - | 6xx series GPUs | | [METAL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_metal.py) | Utilizes Metal for acceleration on Apple devices | - | M1+ Macs; Metal 3.0+ for `bfloat` support | | [CUDA](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cuda.py) | Utilizes CUDA for acceleration on NVIDIA GPUs | nvrtc (default)
PTX (`DEV=CUDA:PTX`) | NVIDIA GPU with CUDA support | diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 36906b4379..d29e4f28b8 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -15,7 +15,7 @@ from tinygrad.runtime.autogen import kfd, hsa, sqtt, amdgpu_kd, amdgpu_drm from tinygrad.runtime.autogen.am import am from tinygrad.runtime.support.elf import elf_loader from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager -from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_soc, import_ip_offsets, import_pmc +from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_soc, import_pmc from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta, USBPCIDevice, MAP_FIXED, MAP_NORESERVE from tinygrad.runtime.support.usb import USB3 from tinygrad.runtime.support.memory import AddrSpace @@ -90,7 +90,7 @@ class AMDComputeQueue(HWQueue): return self def acquire_mem(self, addr=0x0, sz=(1 << 64)-1, gli=1, glm=1, glk=1, glv=1, gl1=1, gl2=1): - if self.dev.target >= (10,0,0): + if self.dev.target[0] != 9: cache_flags_dw = self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLI_INV(gli) \ | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_INV(glm) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLM_WB(glm) \ | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_INV(glk) | self.pm4.PACKET3_ACQUIRE_MEM_GCR_CNTL_GLK_WB(glk) \ @@ -108,7 +108,7 @@ class AMDComputeQueue(HWQueue): return self def release_mem(self, address=0x0, value=0, data_sel=0, int_sel=2, ctxid=0, cache_flush=False): - if self.dev.target >= (10,0,0): + if self.dev.target[0] != 9: cache_flags_dw = 0 if not cache_flush else (self.pm4.PACKET3_RELEASE_MEM_GCR_GLV_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GL1_INV \ | self.pm4.PACKET3_RELEASE_MEM_GCR_GL2_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GLM_WB \ | self.pm4.PACKET3_RELEASE_MEM_GCR_GLM_INV | self.pm4.PACKET3_RELEASE_MEM_GCR_GL2_WB | self.pm4.PACKET3_RELEASE_MEM_GCR_SEQ) @@ -151,7 +151,7 @@ class AMDComputeQueue(HWQueue): def pmc_start(self, counters): self.pmc_reset_counters(en=False) self.wreg(self.gc.regSQ_PERFCOUNTER_CTRL, cs_en=1, ps_en=1, gs_en=1, hs_en=1, **({'vmid_mask':0xffff} if (gfx9:=self.dev.target[0] == 9) else {})) - if self.dev.target[0] >= 11: self.wreg(self.gc.regSQ_PERFCOUNTER_CTRL2, force_en=1, vmid_en=0xffff) + if not gfx9: self.wreg(self.gc.regSQ_PERFCOUNTER_CTRL2, force_en=1, vmid_en=0xffff) end_off = 0 block2pid:dict[str, itertools.count] = collections.defaultdict(lambda: itertools.count()) @@ -163,7 +163,7 @@ class AMDComputeQueue(HWQueue): # gfx11+ and later require even-numbered SQ *_SELECT registers regsample = f'reg{block}_PERFCOUNTER{(pcid:=next(block2pid[block]))}' - if (regsel:=getattr(self.gc, (f'reg{block}_PERFCOUNTER{(pcid*2) if self.dev.target[0]>=11 and block=="SQ" else pcid}_SELECT'), None)) is None: + if (regsel:=getattr(self.gc, (f'reg{block}_PERFCOUNTER{(pcid*2) if not gfx9 and block=="SQ" else pcid}_SELECT'), None)) is None: raise RuntimeError(f'{block} is out of perfcounter registers: ({regsample} is not found)') self.wreg(regsel, perf_sel=idx, **({'simd_mask':0xf, 'sqc_bank_mask':0xf, 'sqc_client_mask':0xf} if gfx9 and block == "SQ" else {})) @@ -206,7 +206,7 @@ class AMDComputeQueue(HWQueue): cu_per_se = prod([x if isinstance(x, int) else 1 for x in global_size]) // ((self.dev.cu_cnt // self.dev.se_cnt) * 4) for xcc in range(self.dev.xccs): with self.pred_exec(xcc_mask=1 << xcc): - for i in range(8 if prg.dev.target >= (11,0,0) else 4): + for i in range(8 if prg.dev.target[0] != 9 else 4): if SQTT_LIMIT_SE > 1: mask = 1 if SQTT_ITRACE_SE_MASK.value & (1 << i) else 0 # only run unmasked shader engines else: sa_mask = (1 << (self.dev.iface.props['cu_per_simd_array'] // 2)) - 1 @@ -282,7 +282,7 @@ class AMDComputeQueue(HWQueue): self.sqtt_config(tracing=True) self.set_grbm() - if self.dev.target[0] > 9: self.wreg(self.gc.regCOMPUTE_THREAD_TRACE_ENABLE, 1) + if self.dev.target[0] != 9: self.wreg(self.gc.regCOMPUTE_THREAD_TRACE_ENABLE, 1) self.memory_barrier() return self @@ -303,7 +303,7 @@ class AMDComputeQueue(HWQueue): self.set_grbm(se=se % self.dev.se_cnt, sh=0) regstatus = self.gc.regSQ_THREAD_TRACE_STATUS.addr[0] - (self.pm4.PACKET3_SET_UCONFIG_REG_START if self.dev.target[0] == 9 else 0) - if self.dev.target >= (10,0,0): + if self.dev.target[0] != 9: self.wait_reg_mem(reg=regstatus, mask=self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('finish_pending'), op=WAIT_REG_MEM_FUNCTION_EQ, value=0) self.sqtt_config(tracing=False) self.wait_reg_mem(reg=regstatus, mask=self.gc.regSQ_THREAD_TRACE_STATUS.fields_mask('busy'), op=WAIT_REG_MEM_FUNCTION_EQ, value=0) @@ -313,7 +313,7 @@ class AMDComputeQueue(HWQueue): self.pkt3(self.pm4.PACKET3_COPY_DATA, 1 << 20 | 2 << 8 | 4, self.gc.regSQ_THREAD_TRACE_WPTR.addr[0], 0, *data64_le(wptrs.va_addr+(se*4))) self.set_grbm() - if self.dev.target[0] > 9: self.spi_config(tracing=False) + if self.dev.target[0] != 9: self.spi_config(tracing=False) self.memory_barrier() return self @@ -348,22 +348,20 @@ class AMDComputeQueue(HWQueue): self.wreg(self.gc.regCOMPUTE_PGM_RSRC3, prg.rsrc3) self.wreg(self.gc.regCOMPUTE_TMPRING_SIZE, prg.dev.tmpring_size) - if prg.dev.has_scratch_base_registers: - for xcc_id in range(self.dev.xccs): - with self.pred_exec(xcc_mask=1<> 8)) - - if (10,0,0) <= prg.dev.target < (11,0,0): self.wreg(self.gc.mmCP_COHER_START_DELAY, 0x20) + # this is what llvm refers to as "architected flat scratch" + for xcc_id in range(self.dev.xccs): + with self.pred_exec(xcc_mask=1<> 8)) self.wreg(self.gc.regCOMPUTE_RESTART_X, 0, 0, 0) self.wreg(self.gc.regCOMPUTE_USER_DATA_0, *user_regs) self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, waves_per_sh=getenv("WAVES_PER_SH")) self.wreg(self.gc.regCOMPUTE_START_X, 0, 0, 0, *local_size, 0, 0) - gfx10p = {'cs_w32_en': int(prg.wave32)} if prg.dev.target >= (10,0,0) else {} self.pkt3(self.pm4.PACKET3_DISPATCH_DIRECT, *global_size, - self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**gfx10p, force_start_at_000=1, compute_shader_en=1)) + self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**({'cs_w32_en': int(prg.wave32)} if prg.dev.target[0] != 9 else {}), + force_start_at_000=1, compute_shader_en=1)) if prg.dev.sqtt_enabled: self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.THREAD_TRACE_MARKER) | self.pm4.EVENT_INDEX(0)) self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | self.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH)) @@ -486,7 +484,7 @@ class AMDCopyQueue(HWQueue): return self def signal(self, signal:AMDSignal, value:sint=0): - fence_flags = self.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if self.dev.target >= (10,0,0) else 0 + fence_flags = self.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if self.dev.target[0] != 9 else 0 self.q(self.sdma.SDMA_OP_FENCE | fence_flags, *data64_le(signal.value_addr), value) if (dev:=signal.owner) is not None and signal.is_timeline and not dev.is_am(): @@ -593,7 +591,7 @@ class AMDProgram(HCQProgram): self.wave32: bool = desc.kernel_code_properties & 0x400 == 0x400 # Set rsrc1.priv=1 on gfx11 to workaround cwsr. - self.rsrc1: int = desc.compute_pgm_rsrc1 | ((1 << 20) if (11,0,0) <= self.dev.target < (12,0,0) else 0) + self.rsrc1: int = desc.compute_pgm_rsrc1 | ((1 << 20) if self.dev.target[0] == 11 else 0) self.rsrc2: int = desc.compute_pgm_rsrc2 | (lds_size << 15) self.rsrc3: int = desc.compute_pgm_rsrc3 self.aql_prog_addr: int = self.lib_gpu.va_addr + rodata_entry @@ -961,30 +959,26 @@ class AMDDevice(HCQCompiled): self.target:tuple[int, ...] = ((trgt:=self.iface.props['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100) self.arch = "gfx%d%x%x" % self.target - if self.target < (9,4,2) or self.target >= (13,0,0): raise RuntimeError(f"Unsupported arch: {self.arch}") + assert (self.target in ((9,4,2),(9,5,0))) or self.target[0] in (11, 12), f"Unsupported arch: {self.arch}" if DEBUG >= 1: print(f"AMDDevice: opening {self.device_id} with target {self.target} arch {self.arch}") self.xccs = self.iface.props.get('num_xcc', 1) self.se_cnt = self.iface.props['array_count'] // self.iface.props['simd_arrays_per_engine'] // self.xccs self.cu_cnt = self.iface.props['simd_count'] // self.iface.props['simd_per_cu'] // self.xccs self.waves_per_cu = self.iface.props['max_waves_per_simd'] * self.iface.props['simd_per_cu'] - self.wave_cnt = (self.cu_cnt * self.waves_per_cu) if self.target >= (10,1,0) else min(self.cu_cnt * 40, self.se_cnt * self.xccs * 512) - # this is what llvm refers to as "architected flat scratch" - self.has_scratch_base_registers = self.target >= (11,0,0) or self.target in {(9,4,2), (9,5,0)} + self.wave_cnt = (self.cu_cnt * self.waves_per_cu) if self.target[0] != 9 else min(self.cu_cnt * 40, self.se_cnt * self.xccs * 512) # https://gitlab.freedesktop.org/agd5f/linux/-/blob/a1fc9f584c4aaf8bc1ebfa459fc57a3f26a290d8/drivers/gpu/drm/amd/amdkfd/kfd_queue.c#L391 - sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000 - if self.target[:2] == (9,5): lds_size_per_cu = self.iface.props["lds_size_in_kb"] << 10 - vgpr_size_per_cu = 0x60000 if self.target in {(11,0,0), (11,0,1), (11,5,1), (12,0,0), (12,0,1)} else \ - 0x80000 if (self.target[:2]) in {(9,4), (9,5)} or self.target in {(9,0,8), (9,0,10)} else 0x40000 + sgrp_size_per_cu, hwreg_size_per_cu = 0x4000, 0x1000 + lds_size_per_cu = self.iface.props["lds_size_in_kb"] << 10 if self.target[:2] == (9,5) else 0x10000 + vgpr_size_per_cu = 0x60000 if self.target in {(11,0,0), (11,0,1), (11,5,1), (12,0,0), (12,0,1)} else 0x80000 if self.target[0] == 9 else 0x40000 wg_data_size = round_up((vgpr_size_per_cu + sgrp_size_per_cu + lds_size_per_cu + hwreg_size_per_cu) * self.cu_cnt, mmap.PAGESIZE) - ctl_stack_size = round_up((12 if self.target >= (10,1,0) else 8) * self.wave_cnt + 8 + 40, mmap.PAGESIZE) - if self.target[0] == 10: ctl_stack_size = min(ctl_stack_size, 0x7000) + ctl_stack_size = round_up((12 if self.target[0] != 9 else 8) * self.wave_cnt + 8 + 40, mmap.PAGESIZE) debug_memory_size = round_up(self.wave_cnt * 32, 64) - self.ip_off = import_ip_offsets(self.target) + self.ip_off = importlib.import_module(f"tinygrad.runtime.autogen.am.{'vega' if self.target[0] == 9 else 'navi'}_offsets") self.soc = import_soc(self.target) - self.pm4 = importlib.import_module(f"tinygrad.runtime.autogen.am.pm4_{'nv' if self.target[0] >= 10 else 'soc15'}") + self.pm4 = importlib.import_module(f"tinygrad.runtime.autogen.am.pm4_{'soc15' if self.target[0] == 9 else 'nv'}") self.sdma = import_module('sdma', min(self.iface.ip_versions[am.SDMA0_HWIP], (6, 0, 0))) self.gc = AMDIP('gc', self.iface.ip_versions[am.GC_HWIP], bases={i: tuple(getattr(self.ip_off, f'GC_BASE__INST{i}_SEG{s}', 0) for s in range(6)) for i in range(6)}) @@ -1017,7 +1011,6 @@ class AMDDevice(HCQCompiled): self.pmc_enabled:bool = PROFILE > 0 and PMC > 0 if self.pmc_enabled: - if self.target[0] not in {9, 11, 12}: raise RuntimeError(f'PMC are not supported on gc:{self.target}') self.iface.require_profile_mode() self.pmc_sched:list[PMCSample] = [] @@ -1036,7 +1029,6 @@ class AMDDevice(HCQCompiled): # SQTT is disabled by default because of runtime overhead and big file sizes (~200mb to Tensor.full() two 4096x4096 tensors and matmul them) self.sqtt_enabled:bool = PROFILE > 0 and SQTT > 0 if self.sqtt_enabled: - if self.target[0] not in {9, 11, 12}: raise RuntimeError(f'SQ Thread Tracing is not supported on gc:{self.target}') self.iface.require_profile_mode() SQTT_BUFFER_SIZE = getenv("SQTT_BUFFER_SIZE", 256) # in mb, per shader engine @@ -1074,7 +1066,7 @@ class AMDDevice(HCQCompiled): if self.max_private_segment_size >= private_segment_size: return lanes_per_wave = 64 # wave64 - mem_alignment_size = 256 if self.target >= (11,0,0) else 1024 + mem_alignment_size = 256 if self.target[0] != 9 else 1024 size_per_thread = round_up(private_segment_size, mem_alignment_size // lanes_per_wave) size_per_xcc = size_per_thread * lanes_per_wave * self.iface.props['max_slots_scratch_cu'] * self.cu_cnt self.scratch, ok = self._realloc(getattr(self, 'scratch', None), size_per_xcc * self.xccs) @@ -1082,18 +1074,18 @@ class AMDDevice(HCQCompiled): # NOTE: xcc logic is correct only for GFX9. max_scratch_waves = self.cu_cnt * self.iface.props['max_slots_scratch_cu'] * self.xccs wave_scratch = ceildiv(lanes_per_wave * size_per_thread, mem_alignment_size) - num_waves = (size_per_xcc // (wave_scratch * mem_alignment_size)) // (self.se_cnt if self.target >= (11,0,0) else 1) + num_waves = (size_per_xcc // (wave_scratch * mem_alignment_size)) // (self.se_cnt if self.target[0] != 9 else 1) - tmpring_t = getattr(hsa, f'union_COMPUTE_TMPRING_SIZE{"_GFX"+str(self.target[0]) if self.target[0] >= 11 else ""}_bitfields') + tmpring_t = getattr(hsa, f'union_COMPUTE_TMPRING_SIZE{"_GFX"+str(self.target[0]) if self.target[0] != 9 else ""}_bitfields') self.tmpring_size = int.from_bytes(tmpring_t(WAVES=min(num_waves, max_scratch_waves), WAVESIZE=wave_scratch), 'little') self.max_private_segment_size = private_segment_size if hasattr(self, 'aql_desc'): gfx9_rsrc = {'NUM_FORMAT':hsa.BUF_NUM_FORMAT_UINT, 'DATA_FORMAT':hsa.BUF_DATA_FORMAT_32, 'ELEMENT_SIZE':1, 'INDEX_STRIDE':3} rsrc = {'DST_SEL_X':hsa.SQ_SEL_X, 'DST_SEL_Y':hsa.SQ_SEL_Y, 'DST_SEL_Z':hsa.SQ_SEL_Z, 'DST_SEL_W':hsa.SQ_SEL_W, 'ADD_TID_ENABLE':1, - 'TYPE':hsa.SQ_RSRC_BUF, **(gfx9_rsrc if self.target[0] < 10 else {'FORMAT':hsa.BUF_FORMAT_32_UINT, 'OOB_SELECT':2})} - rsrc1_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD1{"_GFX11" if self.target[0] >= 11 else ""}_bitfields') - rsrc3_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] >= 10 else ""}_bitfields') + 'TYPE':hsa.SQ_RSRC_BUF, **(gfx9_rsrc if self.target[0] == 9 else {'FORMAT':hsa.BUF_FORMAT_32_UINT, 'OOB_SELECT':2})} + rsrc1_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD1{"_GFX11" if self.target[0] != 9 else ""}_bitfields') + rsrc3_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] != 9 else ""}_bitfields') self.aql_desc.scratch_backing_memory_location = int(self.scratch.va_addr) self.aql_desc.scratch_wave64_lane_byte_size = self.max_private_segment_size * lanes_per_wave // 64 diff --git a/tinygrad/runtime/support/amd.py b/tinygrad/runtime/support/amd.py index ca0173e033..978b4a6d3b 100644 --- a/tinygrad/runtime/support/amd.py +++ b/tinygrad/runtime/support/amd.py @@ -62,8 +62,6 @@ def import_soc(ip): # rocm soc headers have more profiling enums than upstream linux return type("SOC", (object,), import_header(f"aqlprofile/linux/{({9: 'vega10', 10: 'navi10', 11: 'soc21', 12: 'soc24'}[ip[0]])}_enum.h", ROCM_URL)) -def import_ip_offsets(ip): return getattr(tinygrad.runtime.autogen.am, f"{'navi' if ip[0] > 9 else 'vega'}_offsets") - def import_pmc(ip) -> dict[str, tuple[str, int]]: res:dict[str, tuple[str, int]] = {}