diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 3ee72f3f17..cb276daa2f 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -204,9 +204,9 @@ class AMDev(PCIDevImplBase): if reg > len(self.mmio): self.indirect_wreg(reg * 4, val) else: self.mmio[reg] = val - def wreg_pair(self, reg_base:str, lo_suffix:str, hi_suffix:str, val:int): - self.reg(f"{reg_base}{lo_suffix}").write(val & 0xffffffff) - self.reg(f"{reg_base}{hi_suffix}").write(val >> 32) + def wreg_pair(self, reg_base:str, lo_suffix:str, hi_suffix:str, val:int, inst:int=0): + self.reg(f"{reg_base}{lo_suffix}").write(val & 0xffffffff, inst=inst) + self.reg(f"{reg_base}{hi_suffix}").write(val >> 32, inst=inst) def indirect_rreg(self, reg:int) -> int: self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg) diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index bca2f65ab5..0863b183d1 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -26,6 +26,8 @@ class AM_SOC(AM_IP): class AM_GMC(AM_IP): def init_sw(self): + self.vmhubs = len(self.adev.regs_offset[am.MMHUB_HWIP]) + # Memory controller aperture self.mc_base = (self.adev.regMMMC_VM_FB_LOCATION_BASE.read() & 0xFFFFFF) << 24 self.mc_end = self.mc_base + self.adev.mm.vram_size - 1 @@ -43,7 +45,7 @@ class AM_GMC(AM_IP): self.pf_status_reg = lambda ip: f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS{'_LO32' if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else ''}" - def init_hw(self): self.init_hub("MM") + def init_hw(self): self.init_hub("MM", inst_cnt=self.vmhubs) def flush_hdp(self): self.adev.wreg(self.adev.reg("regBIF_BX0_REMAP_HDP_MEM_FLUSH_CNTL").read() // 4, 0x0) def flush_tlb(self, ip:Literal["MM", "GC"], vmid, flush_type=0): @@ -66,10 +68,10 @@ class AM_GMC(AM_IP): # Read back the register to ensure the invalidation is complete self.adev.regMMVM_L2_BANK_SELECT_RESERVED_CID2.read() - def enable_vm_addressing(self, page_table, ip:Literal["MM", "GC"], vmid): - self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12) - self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12) - self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1) + def enable_vm_addressing(self, page_table, ip:Literal["MM", "GC"], vmid, inst): + self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12, inst=inst) + self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12, inst=inst) + self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1, inst=inst) self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1800000, pde0_protection_fault_enable_interrupt=1, pde0_protection_fault_enable_default=1, dummy_page_protection_fault_enable_interrupt=1, dummy_page_protection_fault_enable_default=1, range_protection_fault_enable_interrupt=1, range_protection_fault_enable_default=1, @@ -77,40 +79,42 @@ class AM_GMC(AM_IP): read_protection_fault_enable_interrupt=1, read_protection_fault_enable_default=1, write_protection_fault_enable_interrupt=1, write_protection_fault_enable_default=1, execute_protection_fault_enable_interrupt=1, execute_protection_fault_enable_default=1, - enable_context=1, page_table_depth=(3 - page_table.lv)) + enable_context=1, page_table_depth=(3 - page_table.lv), inst=inst) - def init_hub(self, ip:Literal["MM", "GC"]): + def init_hub(self, ip:Literal["MM", "GC"], inst_cnt:int): # Init system apertures - self.adev.reg(f"reg{ip}MC_VM_AGP_BASE").write(0) - self.adev.reg(f"reg{ip}MC_VM_AGP_BOT").write(0xffffffffffff >> 24) # disable AGP - self.adev.reg(f"reg{ip}MC_VM_AGP_TOP").write(0) + for inst in range(inst_cnt): + self.adev.reg(f"reg{ip}MC_VM_AGP_BASE").write(0, inst=inst) + self.adev.reg(f"reg{ip}MC_VM_AGP_BOT").write(0xffffffffffff >> 24, inst=inst) # disable AGP + self.adev.reg(f"reg{ip}MC_VM_AGP_TOP").write(0, inst=inst) - self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_LOW_ADDR").write(self.mc_base >> 18) - self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_HIGH_ADDR").write(self.mc_end >> 18) - self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_paddr >> 12) - self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_paddr >> 12) + self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_LOW_ADDR").write(self.mc_base >> 18, inst=inst) + self.adev.reg(f"reg{ip}MC_VM_SYSTEM_APERTURE_HIGH_ADDR").write(self.mc_end >> 18, inst=inst) + self.adev.wreg_pair(f"reg{ip}MC_VM_SYSTEM_APERTURE_DEFAULT_ADDR", "_LSB", "_MSB", self.memscratch_paddr >> 12, inst=inst) + self.adev.wreg_pair(f"reg{ip}VM_L2_PROTECTION_FAULT_DEFAULT_ADDR", "_LO32", "_HI32", self.dummy_page_paddr >> 12, inst=inst) - self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_CNTL2").update(active_page_migration_pte_read_retry=1) + self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_CNTL2").update(active_page_migration_pte_read_retry=1, inst=inst) - # Init TLB and cache - self.adev.reg(f"reg{ip}MC_VM_MX_L1_TLB_CNTL").update(enable_l1_tlb=1, system_access_mode=3, enable_advanced_driver_model=1, - system_aperture_unmapped_access=0, eco_bits=0, mtype=self.adev.soc.module.MTYPE_UC) + # Init TLB and cache + self.adev.reg(f"reg{ip}MC_VM_MX_L1_TLB_CNTL").update(enable_l1_tlb=1, system_access_mode=3, enable_advanced_driver_model=1, + system_aperture_unmapped_access=0, eco_bits=0, mtype=self.adev.soc.module.MTYPE_UC, inst=inst) - self.adev.reg(f"reg{ip}VM_L2_CNTL").update(enable_l2_cache=1, enable_l2_fragment_processing=0, enable_default_page_out_to_system_memory=1, - l2_pde0_cache_tag_generation_mode=0, pde_fault_classification=0, context1_identity_access_mode=1, identity_mode_fragment_size=0) - self.adev.reg(f"reg{ip}VM_L2_CNTL2").update(invalidate_all_l1_tlbs=1, invalidate_l2_cache=1) - self.adev.reg(f"reg{ip}VM_L2_CNTL3").write(bank_select=9, l2_cache_bigk_fragment_size=6,l2_cache_4k_associativity=1,l2_cache_bigk_associativity=1) - self.adev.reg(f"reg{ip}VM_L2_CNTL4").write(l2_cache_4k_partition_count=1) - self.adev.reg(f"reg{ip}VM_L2_CNTL5").write(walker_priority_client_id=0x1ff) + self.adev.reg(f"reg{ip}VM_L2_CNTL").update(enable_l2_cache=1, enable_l2_fragment_processing=0, enable_default_page_out_to_system_memory=1, + l2_pde0_cache_tag_generation_mode=0, pde_fault_classification=0, context1_identity_access_mode=1, identity_mode_fragment_size=0, inst=inst) + self.adev.reg(f"reg{ip}VM_L2_CNTL2").update(invalidate_all_l1_tlbs=1, invalidate_l2_cache=1, inst=inst) + self.adev.reg(f"reg{ip}VM_L2_CNTL3").write(bank_select=9, l2_cache_bigk_fragment_size=6,l2_cache_4k_associativity=1, + l2_cache_bigk_associativity=1, inst=inst) + self.adev.reg(f"reg{ip}VM_L2_CNTL4").write(l2_cache_4k_partition_count=1, inst=inst) + self.adev.reg(f"reg{ip}VM_L2_CNTL5").write(walker_priority_client_id=0x1ff, inst=inst) - self.enable_vm_addressing(self.adev.mm.root_page_table, ip, vmid=0) + self.enable_vm_addressing(self.adev.mm.root_page_table, ip, vmid=0, inst=inst) - # Disable identity aperture - self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT1_IDENTITY_APERTURE_LOW_ADDR", "_LO32", "_HI32", 0xfffffffff) - self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT1_IDENTITY_APERTURE_HIGH_ADDR", "_LO32", "_HI32", 0x0) - self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT_IDENTITY_PHYSICAL_OFFSET", "_LO32", "_HI32", 0x0) + # Disable identity aperture + self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT1_IDENTITY_APERTURE_LOW_ADDR", "_LO32", "_HI32", 0xfffffffff, inst=inst) + self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT1_IDENTITY_APERTURE_HIGH_ADDR", "_LO32", "_HI32", 0x0, inst=inst) + self.adev.wreg_pair(f"reg{ip}VM_L2_CONTEXT_IDENTITY_PHYSICAL_OFFSET", "_LO32", "_HI32", 0x0, inst=inst) - for eng_i in range(18): self.adev.wreg_pair(f"reg{ip}VM_INVALIDATE_ENG{eng_i}_ADDR_RANGE", "_LO32", "_HI32", 0x1fffffffff) + for eng_i in range(18): self.adev.wreg_pair(f"reg{ip}VM_INVALIDATE_ENG{eng_i}_ADDR_RANGE", "_LO32", "_HI32", 0x1fffffffff, inst=inst) self.hub_initted[ip] = True @functools.cache # pylint: disable=method-cache-max-size-none @@ -180,40 +184,42 @@ class AM_SMU(AM_IP): return (self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).read() if read_back_arg else None class AM_GFX(AM_IP): + def init_sw(self): self.xccs = len(self.adev.regs_offset[am.GC_HWIP]) + def init_hw(self): # Wait for RLC autoload to complete while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass self._config_gfx_rs64() - self.adev.gmc.init_hub("GC") + self.adev.gmc.init_hub("GC", inst_cnt=self.xccs) # NOTE: Golden reg for gfx11. No values for this reg provided. The kernel just ors 0x20000000 to this reg. - self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000) + for xcc in range(self.xccs): self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000, inst=xcc) - self.adev.regRLC_SRM_CNTL.update(srm_enable=1, auto_incr_addr=1) + for xcc in range(self.xccs): self.adev.regRLC_SRM_CNTL.update(srm_enable=1, auto_incr_addr=1, inst=xcc) self.adev.soc.doorbell_enable(port=0, awid=0x3, awaddr_31_28_value=0x3) self.adev.soc.doorbell_enable(port=3, awid=0x6, awaddr_31_28_value=0x3) - self.adev.regGRBM_CNTL.update(read_timeout=0xff) - for i in range(0, 16): - self._grbm_select(vmid=i) - self.adev.regSH_MEM_CONFIG.write(address_mode=self.adev.soc.module.SH_MEM_ADDRESS_MODE_64, - alignment_mode=self.adev.soc.module.SH_MEM_ALIGNMENT_MODE_UNALIGNED, initial_inst_prefetch=3) + for xcc in range(self.xccs): + self.adev.regGRBM_CNTL.update(read_timeout=0xff, inst=xcc) + for i in range(0, 16): + self._grbm_select(vmid=i, inst=xcc) + self.adev.regSH_MEM_CONFIG.write(address_mode=self.adev.soc.module.SH_MEM_ADDRESS_MODE_64, + alignment_mode=self.adev.soc.module.SH_MEM_ALIGNMENT_MODE_UNALIGNED, initial_inst_prefetch=3, inst=xcc) - # Configure apertures: - # LDS: 0x10000000'00000000 - 0x10000001'00000000 (4GB) - # Scratch: 0x20000000'00000000 - 0x20000001'00000000 (4GB) - self.adev.regSH_MEM_BASES.write(shared_base=0x1, private_base=0x2) - self._grbm_select() + # Configure apertures: + # LDS: 0x10000000'00000000 - 0x10000001'00000000 (4GB) + # Scratch: 0x20000000'00000000 - 0x20000001'00000000 (4GB) + self.adev.regSH_MEM_BASES.write(shared_base=0x1, private_base=0x2, inst=xcc) + self._grbm_select(inst=xcc) - # Configure MEC doorbell range - self.adev.regCP_MEC_DOORBELL_RANGE_LOWER.write(0x0) - self.adev.regCP_MEC_DOORBELL_RANGE_UPPER.write(0x450) - - # Enable MEC - self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=0, mec_pipe0_reset=0, mec_pipe0_active=1, mec_halt=0) + # Configure MEC doorbell range + self.adev.regCP_MEC_DOORBELL_RANGE_LOWER.write(0x0, inst=xcc) + self.adev.regCP_MEC_DOORBELL_RANGE_UPPER.write(0x450, inst=xcc) + # Enable MEC + self.adev.regCP_MEC_RS64_CNTL.update(mec_invalidate_icache=0, mec_pipe0_reset=0, mec_pipe0_active=1, mec_halt=0, inst=xcc) # NOTE: Wait for MEC to be ready. The kernel does udelay here as well. time.sleep(0.05) @@ -261,36 +267,39 @@ class AM_GFX(AM_IP): def set_clockgating_state(self): if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1) - self.adev.regRLC_SAFE_MODE.write(message=1, cmd=1) - wait_cond(lambda: self.adev.regRLC_SAFE_MODE.read() & 0x1, value=0, msg="RLC safe mode timeout") + for xcc in range(self.xccs): + self.adev.regRLC_SAFE_MODE.write(message=1, cmd=1, inst=xcc) + wait_cond(lambda: self.adev.regRLC_SAFE_MODE.read(inst=xcc) & 0x1, value=0, msg="RLC safe mode timeout") - self.adev.regRLC_CGCG_CGLS_CTRL.update(cgcg_gfx_idle_threshold=0x36, cgcg_en=1, cgls_rep_compansat_delay=0xf, cgls_en=1) + self.adev.regRLC_CGCG_CGLS_CTRL.update(cgcg_gfx_idle_threshold=0x36, cgcg_en=1, cgls_rep_compansat_delay=0xf, cgls_en=1, inst=xcc) - self.adev.regCP_RB_WPTR_POLL_CNTL.update(poll_frequency=0x100, idle_poll_count=0x90) - self.adev.regCP_INT_CNTL.update(cntx_busy_int_enable=1, cntx_empty_int_enable=1, cmp_busy_int_enable=1, gfx_idle_int_enable=1) - self.adev.regSDMA0_RLC_CGCG_CTRL.update(cgcg_int_enable=1) - self.adev.regSDMA1_RLC_CGCG_CTRL.update(cgcg_int_enable=1) + self.adev.regCP_RB_WPTR_POLL_CNTL.update(poll_frequency=0x100, idle_poll_count=0x90, inst=xcc) + self.adev.regCP_INT_CNTL.update(cntx_busy_int_enable=1, cntx_empty_int_enable=1, cmp_busy_int_enable=1, gfx_idle_int_enable=1, inst=xcc) + self.adev.regSDMA0_RLC_CGCG_CTRL.update(cgcg_int_enable=1, inst=xcc) + self.adev.regSDMA1_RLC_CGCG_CTRL.update(cgcg_int_enable=1, inst=xcc) - self.adev.regRLC_CGTT_MGCG_OVERRIDE.update(perfmon_clock_state=1, gfxip_fgcg_override=0, gfxip_repeater_fgcg_override=0, - grbm_cgtt_sclk_override=0, rlc_cgtt_sclk_override=0, gfxip_mgcg_override=0, gfxip_cgls_override=0, gfxip_cgcg_override=0) + self.adev.regRLC_CGTT_MGCG_OVERRIDE.update(perfmon_clock_state=1, gfxip_fgcg_override=0, gfxip_repeater_fgcg_override=0, + grbm_cgtt_sclk_override=0, rlc_cgtt_sclk_override=0, gfxip_mgcg_override=0, gfxip_cgls_override=0, gfxip_cgcg_override=0, inst=xcc) - self.adev.regRLC_SAFE_MODE.write(message=0, cmd=1) + self.adev.regRLC_SAFE_MODE.write(message=0, cmd=1, inst=xcc) - def _grbm_select(self, me=0, pipe=0, queue=0, vmid=0): self.adev.regGRBM_GFX_CNTL.write(meid=me, pipeid=pipe, vmid=vmid, queueid=queue) + def _grbm_select(self, me=0, pipe=0, queue=0, vmid=0, inst=0): + self.adev.regGRBM_GFX_CNTL.write(meid=me, pipeid=pipe, vmid=vmid, queueid=queue, inst=inst) def _config_gfx_rs64(self): - def _config_helper(eng_name, cntl_reg, eng_reg, pipe_cnt, me=0): + def _config_helper(eng_name, cntl_reg, eng_reg, pipe_cnt, me=0, xcc=0): for pipe in range(pipe_cnt): - self._grbm_select(me=me, pipe=pipe) - self.adev.wreg_pair(f"regCP_{eng_reg}_PRGRM_CNTR_START", "", "_HI", self.adev.fw.ucode_start[eng_name] >> 2) - self._grbm_select() - self.adev.reg(f"regCP_{cntl_reg}_CNTL").update(**{f"{eng_name.lower()}_pipe{pipe}_reset": 1 for pipe in range(pipe_cnt)}) - self.adev.reg(f"regCP_{cntl_reg}_CNTL").update(**{f"{eng_name.lower()}_pipe{pipe}_reset": 0 for pipe in range(pipe_cnt)}) + self._grbm_select(me=me, pipe=pipe, inst=xcc) + self.adev.wreg_pair(f"regCP_{eng_reg}_PRGRM_CNTR_START", "", "_HI", self.adev.fw.ucode_start[eng_name] >> 2, inst=xcc) + self._grbm_select(inst=xcc) + self.adev.reg(f"regCP_{cntl_reg}_CNTL").update(**{f"{eng_name.lower()}_pipe{pipe}_reset": 1 for pipe in range(pipe_cnt)}, inst=xcc) + self.adev.reg(f"regCP_{cntl_reg}_CNTL").update(**{f"{eng_name.lower()}_pipe{pipe}_reset": 0 for pipe in range(pipe_cnt)}, inst=xcc) - if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0): - _config_helper(eng_name="PFP", cntl_reg="ME", eng_reg="PFP", pipe_cnt=1) - _config_helper(eng_name="ME", cntl_reg="ME", eng_reg="ME", pipe_cnt=1) - _config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=1, me=1) + for xcc in range(self.adev.gfx.xccs): + if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0): + _config_helper(eng_name="PFP", cntl_reg="ME", eng_reg="PFP", pipe_cnt=1, xcc=xcc) + _config_helper(eng_name="ME", cntl_reg="ME", eng_reg="ME", pipe_cnt=1, xcc=xcc) + _config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=1, me=1, xcc=xcc) class AM_IH(AM_IP): def init_sw(self):