mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
DGX Spark and Jetson Thor support (#15939)
This commit is contained in:
committed by
GitHub
parent
5eb1fd5d3c
commit
b36010c55a
@@ -17,7 +17,8 @@ llvm_lib = (r"'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll' if WIN else '/opt/homeb
|
||||
clang_lib = "'/opt/homebrew/opt/llvm@20/lib/libclang.dylib' if OSX else ['clang-20', 'clang']"
|
||||
|
||||
webgpu_lib = "os.path.join(sysconfig.get_paths()['purelib'], 'pydawn', 'lib', 'libwebgpu_dawn.dll') if WIN else 'webgpu_dawn'"
|
||||
nv_lib_path = "[f'/{pre}/cuda/targets/{sysconfig.get_config_vars().get(\"MULTIARCH\", \"\").rsplit(\"-\", 1)[0]}/lib' for pre in ['opt', 'usr/local']]"
|
||||
nv_lib_path = ("[f'/{pre}/cuda/targets/{tgt}/lib' for pre in ['opt', 'usr/local'] for tgt in "
|
||||
"[sysconfig.get_config_vars().get(\"MULTIARCH\", \"\").rsplit(\"-\", 1)[0], 'sbsa-linux']]")
|
||||
|
||||
def load(name, dll, files, **kwargs):
|
||||
if not (f:=(root/(path:=kwargs.pop("path", __name__)).replace('.','/')/f"{name}.py")).exists() or getenv('REGEN'):
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Literal, TypeAlias
|
||||
from tinygrad.runtime.support.c import _IO, _IOW, _IOR, _IOWR
|
||||
from tinygrad.runtime.support import c
|
||||
import sysconfig
|
||||
dll = c.DLL('nvjitlink', 'nvJitLink', [f'/{pre}/cuda/targets/{sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]}/lib' for pre in ['opt', 'usr/local']])
|
||||
dll = c.DLL('nvjitlink', 'nvJitLink', [f'/{pre}/cuda/targets/{tgt}/lib' for pre in ['opt', 'usr/local'] for tgt in [sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0], 'sbsa-linux']])
|
||||
nvJitLinkResult: dict[int, str] = {(NVJITLINK_SUCCESS:=0): 'NVJITLINK_SUCCESS', (NVJITLINK_ERROR_UNRECOGNIZED_OPTION:=1): 'NVJITLINK_ERROR_UNRECOGNIZED_OPTION', (NVJITLINK_ERROR_MISSING_ARCH:=2): 'NVJITLINK_ERROR_MISSING_ARCH', (NVJITLINK_ERROR_INVALID_INPUT:=3): 'NVJITLINK_ERROR_INVALID_INPUT', (NVJITLINK_ERROR_PTX_COMPILE:=4): 'NVJITLINK_ERROR_PTX_COMPILE', (NVJITLINK_ERROR_NVVM_COMPILE:=5): 'NVJITLINK_ERROR_NVVM_COMPILE', (NVJITLINK_ERROR_INTERNAL:=6): 'NVJITLINK_ERROR_INTERNAL'}
|
||||
nvJitLinkInputType: dict[int, str] = {(NVJITLINK_INPUT_NONE:=0): 'NVJITLINK_INPUT_NONE', (NVJITLINK_INPUT_CUBIN:=1): 'NVJITLINK_INPUT_CUBIN', (NVJITLINK_INPUT_PTX:=2): 'NVJITLINK_INPUT_PTX', (NVJITLINK_INPUT_LTOIR:=3): 'NVJITLINK_INPUT_LTOIR', (NVJITLINK_INPUT_FATBIN:=4): 'NVJITLINK_INPUT_FATBIN', (NVJITLINK_INPUT_OBJECT:=5): 'NVJITLINK_INPUT_OBJECT', (NVJITLINK_INPUT_LIBRARY:=6): 'NVJITLINK_INPUT_LIBRARY'}
|
||||
class struct_nvJitLink(c.Struct): pass
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Literal, TypeAlias
|
||||
from tinygrad.runtime.support.c import _IO, _IOW, _IOR, _IOWR
|
||||
from tinygrad.runtime.support import c
|
||||
import sysconfig
|
||||
dll = c.DLL('nvrtc', 'nvrtc', [f'/{pre}/cuda/targets/{sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]}/lib' for pre in ['opt', 'usr/local']])
|
||||
dll = c.DLL('nvrtc', 'nvrtc', [f'/{pre}/cuda/targets/{tgt}/lib' for pre in ['opt', 'usr/local'] for tgt in [sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0], 'sbsa-linux']])
|
||||
nvrtcResult: dict[int, str] = {(NVRTC_SUCCESS:=0): 'NVRTC_SUCCESS', (NVRTC_ERROR_OUT_OF_MEMORY:=1): 'NVRTC_ERROR_OUT_OF_MEMORY', (NVRTC_ERROR_PROGRAM_CREATION_FAILURE:=2): 'NVRTC_ERROR_PROGRAM_CREATION_FAILURE', (NVRTC_ERROR_INVALID_INPUT:=3): 'NVRTC_ERROR_INVALID_INPUT', (NVRTC_ERROR_INVALID_PROGRAM:=4): 'NVRTC_ERROR_INVALID_PROGRAM', (NVRTC_ERROR_INVALID_OPTION:=5): 'NVRTC_ERROR_INVALID_OPTION', (NVRTC_ERROR_COMPILATION:=6): 'NVRTC_ERROR_COMPILATION', (NVRTC_ERROR_BUILTIN_OPERATION_FAILURE:=7): 'NVRTC_ERROR_BUILTIN_OPERATION_FAILURE', (NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION:=8): 'NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION', (NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION:=9): 'NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION', (NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID:=10): 'NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID', (NVRTC_ERROR_INTERNAL_ERROR:=11): 'NVRTC_ERROR_INTERNAL_ERROR'}
|
||||
@dll.bind(c.POINTER[ctypes.c_char], ctypes.c_uint32)
|
||||
def nvrtcGetErrorString(result:ctypes.c_uint32) -> c.POINTER[ctypes.c_char]: ...
|
||||
|
||||
Reference in New Issue
Block a user