From dd3a720c5a455af5a0ba3f0bcdafca243db9f7c7 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Thu, 2 Oct 2025 10:18:08 -0700 Subject: [PATCH] save more lines --- tinygrad/renderer/nir.py | 18 ++++++------------ tinygrad/runtime/support/compiler_mesa.py | 17 ++++++----------- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 8fb6cfd493..2e399eb3ba 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -179,14 +179,6 @@ class NIRRenderer(Renderer): (UPat(Ops.ENDIF, name="x"), lambda ctx,x: ensure(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]]))) ]) - def __init__(self, device): - self.device = device - mesa.glsl_type_singleton_init_or_ref() - - def __del__(self): - try: mesa.glsl_type_singleton_decref() - except AttributeError: pass - @property def nir_options(self): raise NotImplementedError("needs nir_options") def param(self, dtype:DType, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param") @@ -194,6 +186,7 @@ class NIRRenderer(Renderer): self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None) def render(self, uops:list[UOp]): + mesa.glsl_type_singleton_init_or_ref() self.prerender(uops) for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg self.r, self.param_idx, ranges = {}, 0, [] @@ -231,14 +224,16 @@ class NIRRenderer(Renderer): mesa.ralloc_free(self.b.shader) ctypes.CDLL(ctypes.util.find_library('c')).free(blob.data) del self.b, self.r + mesa.glsl_type_singleton_decref() return ret class NAKRenderer(NIRRenderer): - def __init__(self, dev=None, nir_options=None, device="NV"): + device = "NV" + + def __init__(self, dev=None, nir_options=None): if dev: self.dev = dev else: self.__dict__['nir_options'] = nir_options - super().__init__(device) @classmethod def with_opts(cls, opts): return cls(nir_options=opts) @@ -262,13 +257,12 @@ class NAKRenderer(NIRRenderer): return d(intrin) class LVPRenderer(NIRRenderer): + device = "CPU" has_local = False has_shared = False global_max = (1, 0, 0) nir_options = mesa.lvp_nir_options - def __init__(self, device="CPU"): super().__init__(device) - def param(self, dtype:DType, sz:int) -> mesa.nir_def: intrin = mesa.nir_intrinsic_instr_create(self.b.shader, mesa.nir_intrinsic_load_ubo) intrin.contents.num_components = 1 diff --git a/tinygrad/runtime/support/compiler_mesa.py b/tinygrad/runtime/support/compiler_mesa.py index fdbcda74a6..60ed7fd76a 100644 --- a/tinygrad/runtime/support/compiler_mesa.py +++ b/tinygrad/runtime/support/compiler_mesa.py @@ -8,8 +8,7 @@ try: import tinygrad.runtime.autogen.llvm as llvm except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment] def deserialize(enc_src, opts): - blobreader = mesa.struct_blob_reader() - mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src)) + mesa.blob_reader_init(blobreader:=mesa.struct_blob_reader(), src:=base64.b64decode(enc_src), len(src)) return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader) class LVPCompiler(Compiler): @@ -48,16 +47,14 @@ class LVPCompiler(Compiler): def disassemble(self, lib:bytes): cpu_objdump(lib) class NAKCompiler(Compiler): - def __init__(self, arch, warps_per_sm, cache_key="nak"): - self.arch, self.warps_per_sm = arch, warps_per_sm - self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm)) + def __init__(self, arch, warps, cache_key="nak"): + self.arch, self.warps, self.cc = arch, warps, mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps)) self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents) - mesa.glsl_type_singleton_init_or_ref() super().__init__(f"compile_{cache_key}_{arch}") - def __del__(self): - mesa.nak_compiler_destroy(self.cc) - mesa.glsl_type_singleton_decref() + def __del__(self): mesa.nak_compiler_destroy(self.cc) + + def __reduce__(self): return NAKCompiler, (self.arch, self.warps) def compile(self, src) -> bytes: shader = deserialize(src, self.nir_options) @@ -74,8 +71,6 @@ class NAKCompiler(Compiler): print(subprocess.check_output(['nvdisasm', "-b", f"SM{self.arch[3:]}", fn]).decode('utf-8')) except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.") - def __reduce__(self): return NAKCompiler, (self.arch, self.warps_per_sm) - def parse_nak_shader(shader:bytes) -> Tuple[memoryview, int, int, int]: info = mesa.struct_nak_shader_info.from_buffer(shader) return (memoryview(shader[ctypes.sizeof(info):]), info.num_gprs, round_up(info.cs.smem_size, 0x80), round_up(info.slm_size, 0x10))