mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-07-04 13:02:09 +08:00
69 lines
4.5 KiB
Python
# mypy: disable-error-code="empty-body"
|
|
from __future__ import annotations
|
|
import ctypes
|
|
from typing import Annotated, Literal, TypeAlias
|
|
from tinygrad.runtime.support.c import _IO, _IOW, _IOR, _IOWR
|
|
from tinygrad.runtime.support import c
|
|
import sysconfig
|
|
# Shared-library handle for NVRTC. Besides the library name(s) given to c.DLL, two
# candidate directories are supplied: /opt/cuda/targets/<arch>/lib and
# /usr/local/cuda/targets/<arch>/lib, where <arch> is sysconfig's MULTIARCH with its
# last '-' component dropped (e.g. 'x86_64-linux-gnu' -> 'x86_64-linux').
# An empty/missing MULTIARCH yields '/<prefix>/cuda/targets//lib'.
# NOTE(review): how c.DLL consumes these paths (search order, fallback) lives in
# tinygrad.runtime.support.c — confirm there.
dll = c.DLL('nvrtc', 'nvrtc', [
  f'/{prefix}/cuda/targets/{arch}/lib'
  for arch in (sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0],)
  for prefix in ('opt', 'usr/local')
])
|
|
# Status code type returned by every NVRTC binding in this file except
# nvrtcGetErrorString. Backed by a 32-bit unsigned C integer (ctypes.c_uint32);
# members are registered below via nvrtcResult.define(name, value).
class nvrtcResult(Annotated[int, ctypes.c_uint32], c.Enum): pass
|
|
# nvrtcResult members with their C enum values 0..11. NVRTC_SUCCESS (0) is the only
# member not prefixed NVRTC_ERROR_. NOTE(review): newer CUDA releases may define
# additional codes after NVRTC_ERROR_INTERNAL_ERROR — regenerate from the installed
# nvrtc.h if they are needed.
NVRTC_SUCCESS = nvrtcResult.define('NVRTC_SUCCESS', 0)
|
|
NVRTC_ERROR_OUT_OF_MEMORY = nvrtcResult.define('NVRTC_ERROR_OUT_OF_MEMORY', 1)
|
|
NVRTC_ERROR_PROGRAM_CREATION_FAILURE = nvrtcResult.define('NVRTC_ERROR_PROGRAM_CREATION_FAILURE', 2)
|
|
NVRTC_ERROR_INVALID_INPUT = nvrtcResult.define('NVRTC_ERROR_INVALID_INPUT', 3)
|
|
NVRTC_ERROR_INVALID_PROGRAM = nvrtcResult.define('NVRTC_ERROR_INVALID_PROGRAM', 4)
|
|
NVRTC_ERROR_INVALID_OPTION = nvrtcResult.define('NVRTC_ERROR_INVALID_OPTION', 5)
|
|
NVRTC_ERROR_COMPILATION = nvrtcResult.define('NVRTC_ERROR_COMPILATION', 6)
|
|
NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = nvrtcResult.define('NVRTC_ERROR_BUILTIN_OPERATION_FAILURE', 7)
|
|
NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = nvrtcResult.define('NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION', 8)
|
|
NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = nvrtcResult.define('NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION', 9)
|
|
NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = nvrtcResult.define('NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID', 10)
|
|
NVRTC_ERROR_INTERNAL_ERROR = nvrtcResult.define('NVRTC_ERROR_INTERNAL_ERROR', 11)
|
|
|
|
# Binding for NVRTC `nvrtcGetErrorString`: takes an nvrtcResult and, unlike the other
# bindings here, returns a char pointer (presumably the code's printable name).
# NOTE(review): NVIDIA documents this string as owned by NVRTC — do not free it.
# NOTE(review): the `...` body is presumably never run; dll.bind appears to build
# the foreign call from the annotations — see tinygrad.runtime.support.c.
@dll.bind
|
|
def nvrtcGetErrorString(result:nvrtcResult) -> c.POINTER[Annotated[bytes, ctypes.c_char]]: ...
|
|
# Binding for `nvrtcVersion(int *major, int *minor)`: both arguments are int32
# pointers, presumably filled in with the NVRTC library version.
@dll.bind
|
|
def nvrtcVersion(major:c.POINTER[Annotated[int, ctypes.c_int32]], minor:c.POINTER[Annotated[int, ctypes.c_int32]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetNumSupportedArchs(int *numArchs)`: presumably writes how many
# entries nvrtcGetSupportedArchs will produce, so callers can size that buffer.
@dll.bind
|
|
def nvrtcGetNumSupportedArchs(numArchs:c.POINTER[Annotated[int, ctypes.c_int32]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetSupportedArchs(int *supportedArchs)`: takes an int32 buffer.
# NOTE(review): the buffer presumably must hold the count reported by
# nvrtcGetNumSupportedArchs — confirm against the NVRTC docs.
@dll.bind
|
|
def nvrtcGetSupportedArchs(supportedArchs:c.POINTER[Annotated[int, ctypes.c_int32]]) -> nvrtcResult: ...
|
|
# Opaque C struct behind the nvrtcProgram handle. No _fields_ are declared here, and
# this file only ever refers to it through c.POINTER.
class struct__nvrtcProgram(ctypes.Structure): pass
|
|
# Program handle passed by value to most NVRTC bindings below: a pointer to the opaque
# struct. nvrtcCreateProgram and nvrtcDestroyProgram take a pointer to this handle.
nvrtcProgram: TypeAlias = c.POINTER[struct__nvrtcProgram]
|
|
# Binding for `nvrtcCreateProgram`.
# - `prog` is a pointer to an nvrtcProgram handle, presumably written on success.
# - `src` and `name` are char pointers.
# - `headers` and `includeNames` are arrays of char pointers, presumably with
#   `numHeaders` entries each.
@dll.bind
|
|
def nvrtcCreateProgram(prog:c.POINTER[nvrtcProgram], src:c.POINTER[Annotated[bytes, ctypes.c_char]], name:c.POINTER[Annotated[bytes, ctypes.c_char]], numHeaders:Annotated[int, ctypes.c_int32], headers:c.POINTER[c.POINTER[Annotated[bytes, ctypes.c_char]]], includeNames:c.POINTER[c.POINTER[Annotated[bytes, ctypes.c_char]]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcDestroyProgram`. Takes a pointer to the handle, not the handle
# itself. NOTE(review): NVRTC docs say the handle is reset to NULL on return — confirm.
@dll.bind
|
|
def nvrtcDestroyProgram(prog:c.POINTER[nvrtcProgram]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcCompileProgram`: compiles `prog` with `numOptions` option strings
# passed as an array of char pointers. On a failure status (e.g.
# NVRTC_ERROR_COMPILATION), details are presumably available via nvrtcGetProgramLog.
@dll.bind
|
|
def nvrtcCompileProgram(prog:nvrtcProgram, numOptions:Annotated[int, ctypes.c_int32], options:c.POINTER[c.POINTER[Annotated[bytes, ctypes.c_char]]]) -> nvrtcResult: ...
|
|
# C size_t: the pointee type of the `...SizeRet` parameters below (presumably
# out-parameters). NOTE(review): fixed at 64 bits (ctypes.c_uint64), which equals
# size_t only on 64-bit hosts; ctypes.c_size_t is the platform-exact type.
size_t: TypeAlias = Annotated[int, ctypes.c_uint64]
|
|
# Binding for `nvrtcGetPTXSize`: the size_t pointer presumably receives the byte size
# of the compiled PTX, used to size the buffer for nvrtcGetPTX.
# NOTE(review): per NVRTC docs the size includes the trailing NUL — confirm.
@dll.bind
|
|
def nvrtcGetPTXSize(prog:nvrtcProgram, ptxSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetPTX`: `ptx` is a char buffer, presumably caller-allocated with
# at least the size reported by nvrtcGetPTXSize.
@dll.bind
|
|
def nvrtcGetPTX(prog:nvrtcProgram, ptx:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetCUBINSize`: the size_t pointer presumably receives the cubin
# byte size. NOTE(review): per NVRTC docs a cubin only exists when compiling for a real
# architecture (sm_XX), not a virtual one (compute_XX) — confirm.
@dll.bind
|
|
def nvrtcGetCUBINSize(prog:nvrtcProgram, cubinSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetCUBIN`: `cubin` is a char buffer, presumably caller-allocated
# with at least the size reported by nvrtcGetCUBINSize.
@dll.bind
|
|
def nvrtcGetCUBIN(prog:nvrtcProgram, cubin:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetNVVMSize`: the size_t pointer presumably receives the NVVM IR
# byte size. NOTE(review): NVIDIA marks the NVVM getters deprecated in CUDA 12 in favor
# of the LTOIR pair below — confirm for the installed version.
@dll.bind
|
|
def nvrtcGetNVVMSize(prog:nvrtcProgram, nvvmSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetNVVM`: `nvvm` is a char buffer, presumably caller-allocated
# with at least the size reported by nvrtcGetNVVMSize.
@dll.bind
|
|
def nvrtcGetNVVM(prog:nvrtcProgram, nvvm:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetLTOIRSize`: the size_t pointer presumably receives the LTO IR
# byte size. NOTE(review): per NVRTC docs LTO IR is produced only when compiling with
# -dlto — confirm.
@dll.bind
|
|
def nvrtcGetLTOIRSize(prog:nvrtcProgram, LTOIRSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetLTOIR`: `LTOIR` is a char buffer, presumably caller-allocated
# with at least the size reported by nvrtcGetLTOIRSize.
@dll.bind
|
|
def nvrtcGetLTOIR(prog:nvrtcProgram, LTOIR:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetOptiXIRSize`: the size_t pointer presumably receives the OptiX
# IR byte size. NOTE(review): per NVRTC docs this requires compiling with --optix-ir.
@dll.bind
|
|
def nvrtcGetOptiXIRSize(prog:nvrtcProgram, optixirSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetOptiXIR`: `optixir` is a char buffer, presumably
# caller-allocated with at least the size reported by nvrtcGetOptiXIRSize.
@dll.bind
|
|
def nvrtcGetOptiXIR(prog:nvrtcProgram, optixir:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetProgramLogSize`: the size_t pointer presumably receives the
# byte size of the compile log, used to size the buffer for nvrtcGetProgramLog.
@dll.bind
|
|
def nvrtcGetProgramLogSize(prog:nvrtcProgram, logSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetProgramLog`: `log` is a char buffer, presumably
# caller-allocated with at least the size reported by nvrtcGetProgramLogSize.
@dll.bind
|
|
def nvrtcGetProgramLog(prog:nvrtcProgram, log:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcAddNameExpression`: registers a C++ name expression (char
# pointer) on `prog`. The NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION code
# defined above implies this must be called before nvrtcCompileProgram.
@dll.bind
|
|
def nvrtcAddNameExpression(prog:nvrtcProgram, name_expression:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
|
|
# Binding for `nvrtcGetLoweredName`.
# - `lowered_name` is a pointer to a char pointer, presumably set to the mangled name
#   for `name_expression`.
# - NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION (defined above) implies it is
#   only valid after compilation.
# NOTE(review): per NVRTC docs the returned string is owned by `prog` — do not free it.
@dll.bind
|
|
def nvrtcGetLoweredName(prog:nvrtcProgram, name_expression:c.POINTER[Annotated[bytes, ctypes.c_char]], lowered_name:c.POINTER[c.POINTER[Annotated[bytes, ctypes.c_char]]]) -> nvrtcResult: ...
|
|
# NOTE(review): presumably finalizes the ctypes record types declared in this module
# (e.g. struct__nvrtcProgram) — see tinygrad.runtime.support.c to confirm.
c.init_records()
|
|
# Appears to be residue of a C `__DEPRECATED__(msg)` macro carried over by the binding
# generator. NOTE(review): `__attribute__` and `deprecated` are not defined in this
# file, so calling __DEPRECATED__(msg) raises NameError unless they come from
# elsewhere. Defining the lambda is harmless because its body runs only when called.
__DEPRECATED__ = lambda msg: __attribute__((deprecated(msg))) # type: ignore