Files
tinygrad/tinygrad/runtime/support/compiler_mesa.py
2025-12-05 12:36:35 -05:00

87 lines
4.4 KiB
Python

import base64, ctypes, pathlib, tempfile, hashlib
from tinygrad.device import Compiler
from tinygrad.helpers import cpu_objdump, system
from tinygrad.runtime.autogen import mesa
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, expect, cerr
try: from tinygrad.runtime.autogen import llvm
except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment]
def deserialize(enc_src, opts):
    """Decode a base64-encoded serialized NIR shader and deserialize it through Mesa.

    Args:
        enc_src: base64-encoded serialized shader blob.
        opts: compiler options object; it is cast to a pointer to
            ``mesa.nir_shader_compiler_options`` before being handed to Mesa.

    Returns:
        The value returned by ``mesa.nir_deserialize`` (callers in this file
        dereference it via ``.contents`` and release it with ``mesa.ralloc_free``).
    """
    reader = mesa.struct_blob_reader()
    # Bind the decoded bytes to a local so the buffer outlives the deserialize call;
    # presumably the blob reader references this memory rather than copying it.
    raw = base64.b64decode(enc_src)
    mesa.blob_reader_init(reader, raw, len(raw))
    options_ptr = ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options))
    return mesa.nir_deserialize(None, options_ptr, reader)
class NIRCompiler(Compiler):
    """Base class for compilers that consume serialized NIR shaders through Mesa.

    Every successfully initialised instance holds one reference on Mesa's GLSL
    type singleton and gives it back when the instance is finalised.
    """

    def __init__(self, cache_key):
        """Take a GLSL type singleton reference, then initialise ``Compiler``.

        Args:
            cache_key: forwarded unchanged to ``Compiler.__init__``.
        """
        mesa.glsl_type_singleton_init_or_ref()
        # Remember that the reference is held *before* anything else can raise, so
        # __del__ releases it exactly once and never releases one that was not taken.
        self._glsl_ref_held = True
        super().__init__(cache_key)

    def __del__(self):
        """Drop the GLSL type singleton reference if this instance holds one.

        Python runs ``__del__`` even when ``__init__`` raised part-way (including
        a subclass ``__init__`` that failed before calling this class's
        ``__init__``); in that case no reference was taken and none is released.
        """
        if getattr(self, "_glsl_ref_held", False):
            self._glsl_ref_held = False
            mesa.glsl_type_singleton_decref()
class LVPCompiler(CPULLVMCompiler, NIRCompiler):
    """Compile serialized NIR shaders to CPU object code via Mesa's gallivm helpers and LLVM."""

    def __init__(self, cache_key="lvp"):
        """Initialise both bases explicitly.

        Args:
            cache_key: suffix of the ``compile_<cache_key>`` key given to ``NIRCompiler``.
        """
        # The two bases take different constructor arguments, so they are
        # initialised one by one instead of through a cooperative super() chain.
        CPULLVMCompiler.__init__(self)
        NIRCompiler.__init__(self, f"compile_{cache_key}")

    def __del__(self):
        """Finalise both bases, in the reverse of their initialisation order."""
        NIRCompiler.__del__(self)
        CPULLVMCompiler.__del__(self)

    def compile(self, src) -> bytes:
        """Lower a serialized NIR shader to an object file.

        Args:
            src: base64-encoded serialized NIR shader, as accepted by ``deserialize``.

        Returns:
            The bytes of the object file emitted by the LLVM target machine.

        Raises:
            Whatever ``expect`` raises when the target machine fails to emit the
            object. The shader, the gallivm state and the LLVM memory buffer are
            released on both the success and the error path.
        """
        shader = deserialize(src, mesa.lvp_nir_options)
        try:
            ctx = llvm.LLVMGetGlobalContext()
            mesa_ctx = ctypes.cast(ctx, ctypes.POINTER(mesa.struct_LLVMOpaqueContext))
            gallivm = mesa.gallivm_create(None, mesa.lp_context_ref(mesa_ctx, True), None).contents
            try:
                # The same module/builder handles are needed under both bindings'
                # types: Mesa's for the lp_* helpers and LLVM's for the LLVM-C calls.
                module = ctypes.cast(gallivm.module, llvm.LLVMModuleRef)
                builder = ctypes.cast(gallivm.builder, llvm.LLVMBuilderRef)
                params = mesa.struct_lp_build_tgsi_params(
                    mesa.struct_lp_type(floating=True, sign=True, width=32, length=4),
                    resources_type=mesa.lp_build_jit_resources_type(gallivm),
                    mask=ctypes.pointer(mesa.struct_lp_build_mask_context()))

                # Build `void <shader name>(resources*)` and position the builder in
                # its entry block.
                pt = llvm.LLVMPointerType(ctypes.cast(params.resources_type, llvm.LLVMTypeRef), 0)
                fn_type = llvm.LLVMFunctionType(llvm.LLVMVoidTypeInContext(ctx), pt, 1, 0)
                fn = llvm.LLVMAddFunction(module, shader.contents.info.name, fn_type)
                llvm.LLVMPositionBuilderAtEnd(builder, llvm.LLVMAppendBasicBlockInContext(ctx, fn, b"entry"))

                params.consts_ptr = mesa.lp_build_struct_get_ptr2(
                    gallivm, params.resources_type, ctypes.cast(llvm.LLVMGetParam(fn, 0), mesa.LLVMValueRef),
                    mesa.LP_JIT_RES_CONSTANTS, b"constants")
                mesa.lp_build_mask_begin(params.mask, gallivm, params.type, mesa.lp_build_one(gallivm, params.type))
                mesa.lp_build_mask_end(params.mask)
                mesa.lp_build_nir_soa(gallivm, shader, params, None)
                llvm.LLVMBuildRetVoid(builder)

                mesa.gallivm_verify_function(gallivm, ctypes.cast(fn, mesa.LLVMValueRef))
                mesa.lp_passmgr_run(gallivm.passmgr, gallivm.module,
                                    ctypes.cast(self.target_machine, mesa.LLVMTargetMachineRef), gallivm.module_name)

                err, buf = cerr(), llvm.LLVMMemoryBufferRef()
                status = llvm.LLVMTargetMachineEmitToMemoryBuffer(
                    self.target_machine, module, llvm.LLVMObjectFile, err, ctypes.pointer(buf))
                obj_buf = expect(status, err, buf)
                try:
                    # Copy the object out of LLVM-owned memory before the buffer is released.
                    return ctypes.string_at(llvm.LLVMGetBufferStart(obj_buf), llvm.LLVMGetBufferSize(obj_buf))
                finally:
                    # The emitted buffer is owned by the caller; it was previously never freed.
                    llvm.LLVMDisposeMemoryBuffer(obj_buf)
            finally:
                mesa.gallivm_destroy(gallivm)
        finally:
            mesa.ralloc_free(shader)

    def disassemble(self, lib: bytes):
        """Disassemble the object file *lib* by handing it to ``cpu_objdump``."""
        cpu_objdump(lib)
class NAKCompiler(NIRCompiler):
    """Compile serialized NIR shaders with Mesa's NAK compiler for a given NVIDIA architecture."""

    def __init__(self, arch, warps_per_sm, cache_key="nak"):
        """Create the NAK compiler handle and snapshot its NIR options.

        Args:
            arch: architecture string; everything from index 3 onward must parse as
                the integer SM version (presumably of the form ``"sm_NN"`` - confirm
                against callers).
            warps_per_sm: passed to Mesa as ``max_warps_per_mp``.
            cache_key: prefix used to build the ``compile_<cache_key>_<arch>`` cache key.
        """
        self.arch, self.warps_per_sm = arch, warps_per_sm
        self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm))
        self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents)
        super().__init__(f"compile_{cache_key}_{arch}")

    def __del__(self):
        """Destroy the NAK compiler handle, then run the base-class finaliser.

        ``__del__`` also runs when ``__init__`` raised before ``self.cc`` was
        assigned (e.g. an ``arch`` whose suffix is not an integer). In that case
        there is nothing to destroy and the base ``__init__`` never ran, so the
        method returns instead of raising ``AttributeError`` inside the finaliser.
        """
        cc = getattr(self, "cc", None)
        if cc is None:
            return
        mesa.nak_compiler_destroy(cc)
        super().__del__()

    def __reduce__(self):
        """Pickle by re-running the constructor with the stored arch and warps_per_sm."""
        # NOTE(review): a non-default cache_key is not carried over to the unpickled copy.
        return NAKCompiler, (self.arch, self.warps_per_sm)

    def compile(self, src) -> bytes:
        """Compile a serialized NIR shader.

        Args:
            src: base64-encoded serialized NIR shader, as accepted by ``deserialize``.

        Returns:
            The raw bytes of the shader-info struct followed by the machine code.

        The shader and the compiled shader binary are released on both the
        success and the error path.
        """
        shader = deserialize(src, self.nir_options)
        try:
            mesa.nak_preprocess_nir(shader, self.cc)
            out = mesa.nak_compile_shader(shader, False, self.cc, 0, None).contents
            try:
                # Both pieces are copied into Python bytes before the binary is destroyed.
                return bytes(out.info) + ctypes.string_at(out.code, out.code_size)
            finally:
                mesa.nak_shader_bin_destroy(out)
        finally:
            mesa.ralloc_free(shader)

    def disassemble(self, lib: bytes):
        """Print the nvdisasm listing for *lib*; best-effort, failures are only reported.

        Args:
            lib: bytes in the layout produced by ``compile``.
        """
        try:
            fn = (pathlib.Path(tempfile.gettempdir()) / f"tinynak_{hashlib.md5(lib).hexdigest()}").as_posix()
            # compile() returns bytes(info) + code; skip sizeof(struct_nak_shader_info)
            # bytes (presumably exactly that info header) so only code reaches nvdisasm.
            with open(fn, "wb") as f:
                f.write(lib[ctypes.sizeof(mesa.struct_nak_shader_info):])
            # NOTE(review): the temporary file is left in the temp directory afterwards.
            print(system(f"nvdisasm -b SM{self.arch[3:]} {fn}"))
        except Exception as e:
            print("Failed to generate SASS", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")