diff --git a/tinygrad/runtime/autogen/__init__.py b/tinygrad/runtime/autogen/__init__.py index 159058657e..19aff5ad14 100644 --- a/tinygrad/runtime/autogen/__init__.py +++ b/tinygrad/runtime/autogen/__init__.py @@ -17,7 +17,8 @@ llvm_lib = (r"'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll' if WIN else '/opt/homeb clang_lib = "'/opt/homebrew/opt/llvm@20/lib/libclang.dylib' if OSX else ['clang-20', 'clang']" webgpu_lib = "os.path.join(sysconfig.get_paths()['purelib'], 'pydawn', 'lib', 'libwebgpu_dawn.dll') if WIN else 'webgpu_dawn'" -nv_lib_path = "[f'/{pre}/cuda/targets/{sysconfig.get_config_vars().get(\"MULTIARCH\", \"\").rsplit(\"-\", 1)[0]}/lib' for pre in ['opt', 'usr/local']]" +nv_lib_path = ("[f'/{pre}/cuda/targets/{tgt}/lib' for pre in ['opt', 'usr/local'] for tgt in " + "[sysconfig.get_config_vars().get(\"MULTIARCH\", \"\").rsplit(\"-\", 1)[0], 'sbsa-linux']]") def load(name, dll, files, **kwargs): if not (f:=(root/(path:=kwargs.pop("path", __name__)).replace('.','/')/f"{name}.py")).exists() or getenv('REGEN'): diff --git a/tinygrad/runtime/autogen/nvjitlink.py b/tinygrad/runtime/autogen/nvjitlink.py index 60f1b780ae..eb9a013a8e 100644 --- a/tinygrad/runtime/autogen/nvjitlink.py +++ b/tinygrad/runtime/autogen/nvjitlink.py @@ -5,7 +5,7 @@ from typing import Literal, TypeAlias from tinygrad.runtime.support.c import _IO, _IOW, _IOR, _IOWR from tinygrad.runtime.support import c import sysconfig -dll = c.DLL('nvjitlink', 'nvJitLink', [f'/{pre}/cuda/targets/{sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]}/lib' for pre in ['opt', 'usr/local']]) +dll = c.DLL('nvjitlink', 'nvJitLink', [f'/{pre}/cuda/targets/{tgt}/lib' for pre in ['opt', 'usr/local'] for tgt in [sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0], 'sbsa-linux']]) nvJitLinkResult: dict[int, str] = {(NVJITLINK_SUCCESS:=0): 'NVJITLINK_SUCCESS', (NVJITLINK_ERROR_UNRECOGNIZED_OPTION:=1): 'NVJITLINK_ERROR_UNRECOGNIZED_OPTION', (NVJITLINK_ERROR_MISSING_ARCH:=2): 'NVJITLINK_ERROR_MISSING_ARCH', (NVJITLINK_ERROR_INVALID_INPUT:=3): 'NVJITLINK_ERROR_INVALID_INPUT', (NVJITLINK_ERROR_PTX_COMPILE:=4): 'NVJITLINK_ERROR_PTX_COMPILE', (NVJITLINK_ERROR_NVVM_COMPILE:=5): 'NVJITLINK_ERROR_NVVM_COMPILE', (NVJITLINK_ERROR_INTERNAL:=6): 'NVJITLINK_ERROR_INTERNAL'} nvJitLinkInputType: dict[int, str] = {(NVJITLINK_INPUT_NONE:=0): 'NVJITLINK_INPUT_NONE', (NVJITLINK_INPUT_CUBIN:=1): 'NVJITLINK_INPUT_CUBIN', (NVJITLINK_INPUT_PTX:=2): 'NVJITLINK_INPUT_PTX', (NVJITLINK_INPUT_LTOIR:=3): 'NVJITLINK_INPUT_LTOIR', (NVJITLINK_INPUT_FATBIN:=4): 'NVJITLINK_INPUT_FATBIN', (NVJITLINK_INPUT_OBJECT:=5): 'NVJITLINK_INPUT_OBJECT', (NVJITLINK_INPUT_LIBRARY:=6): 'NVJITLINK_INPUT_LIBRARY'} class struct_nvJitLink(c.Struct): pass diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py index 3f8f390bef..b8e50c9fba 100644 --- a/tinygrad/runtime/autogen/nvrtc.py +++ b/tinygrad/runtime/autogen/nvrtc.py @@ -5,7 +5,7 @@ from typing import Literal, TypeAlias from tinygrad.runtime.support.c import _IO, _IOW, _IOR, _IOWR from tinygrad.runtime.support import c import sysconfig -dll = c.DLL('nvrtc', 'nvrtc', [f'/{pre}/cuda/targets/{sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]}/lib' for pre in ['opt', 'usr/local']]) +dll = c.DLL('nvrtc', 'nvrtc', [f'/{pre}/cuda/targets/{tgt}/lib' for pre in ['opt', 'usr/local'] for tgt in [sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0], 'sbsa-linux']]) nvrtcResult: dict[int, str] = {(NVRTC_SUCCESS:=0): 'NVRTC_SUCCESS', (NVRTC_ERROR_OUT_OF_MEMORY:=1): 'NVRTC_ERROR_OUT_OF_MEMORY', (NVRTC_ERROR_PROGRAM_CREATION_FAILURE:=2): 'NVRTC_ERROR_PROGRAM_CREATION_FAILURE', (NVRTC_ERROR_INVALID_INPUT:=3): 'NVRTC_ERROR_INVALID_INPUT', (NVRTC_ERROR_INVALID_PROGRAM:=4): 'NVRTC_ERROR_INVALID_PROGRAM', (NVRTC_ERROR_INVALID_OPTION:=5): 'NVRTC_ERROR_INVALID_OPTION', (NVRTC_ERROR_COMPILATION:=6): 'NVRTC_ERROR_COMPILATION', (NVRTC_ERROR_BUILTIN_OPERATION_FAILURE:=7): 'NVRTC_ERROR_BUILTIN_OPERATION_FAILURE', (NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION:=8): 'NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION', (NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION:=9): 'NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION', (NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID:=10): 'NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID', (NVRTC_ERROR_INTERNAL_ERROR:=11): 'NVRTC_ERROR_INTERNAL_ERROR'} @dll.bind(c.POINTER[ctypes.c_char], ctypes.c_uint32) def nvrtcGetErrorString(result:ctypes.c_uint32) -> c.POINTER[ctypes.c_char]: ...