Files
StarPilot/tinygrad_repo/tinygrad/runtime/autogen/nvrtc.py
T
firestar5683 d0e1db6766 StarPilot
2026-03-22 03:15:05 -05:00

69 lines
4.5 KiB
Python

# mypy: disable-error-code="empty-body"
from __future__ import annotations
import ctypes
from typing import Annotated, Literal, TypeAlias
from tinygrad.runtime.support.c import _IO, _IOW, _IOR, _IOWR
from tinygrad.runtime.support import c
import sysconfig
# Handle for the nvrtc library, built through the project's c.DLL wrapper.
# The list argument holds one "/<prefix>/cuda/targets/<arch>/lib" directory per prefix
# ('opt' and 'usr/local'), where <arch> is the interpreter's MULTIARCH config value with
# its last '-'-separated component dropped (an empty string when MULTIARCH is not set).
# NOTE(review): presumably these are extra search directories for the shared object —
# confirm against c.DLL.
dll = c.DLL(
  'nvrtc',
  'nvrtc',
  [
    f'/{pre}/cuda/targets/{arch}/lib'
    for arch in [sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]]
    for pre in ['opt', 'usr/local']
  ],
)
# Status-code enumeration returned by the nvrtc bindings in this module, declared as a
# c.Enum annotated with a 32-bit unsigned integer (ctypes.c_uint32).
# Each member is created with `nvrtcResult.define(name, value)` and bound to a
# module-level constant of the same name; the values declared here run from 0 to 11.
class nvrtcResult(Annotated[int, ctypes.c_uint32], c.Enum): pass
NVRTC_SUCCESS = nvrtcResult.define('NVRTC_SUCCESS', 0)
NVRTC_ERROR_OUT_OF_MEMORY = nvrtcResult.define('NVRTC_ERROR_OUT_OF_MEMORY', 1)
NVRTC_ERROR_PROGRAM_CREATION_FAILURE = nvrtcResult.define('NVRTC_ERROR_PROGRAM_CREATION_FAILURE', 2)
NVRTC_ERROR_INVALID_INPUT = nvrtcResult.define('NVRTC_ERROR_INVALID_INPUT', 3)
NVRTC_ERROR_INVALID_PROGRAM = nvrtcResult.define('NVRTC_ERROR_INVALID_PROGRAM', 4)
NVRTC_ERROR_INVALID_OPTION = nvrtcResult.define('NVRTC_ERROR_INVALID_OPTION', 5)
NVRTC_ERROR_COMPILATION = nvrtcResult.define('NVRTC_ERROR_COMPILATION', 6)
NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = nvrtcResult.define('NVRTC_ERROR_BUILTIN_OPERATION_FAILURE', 7)
NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = nvrtcResult.define('NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION', 8)
NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = nvrtcResult.define('NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION', 9)
NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = nvrtcResult.define('NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID', 10)
NVRTC_ERROR_INTERNAL_ERROR = nvrtcResult.define('NVRTC_ERROR_INTERNAL_ERROR', 11)
# Signature stub for `nvrtcGetErrorString`, wrapped by `dll.bind` (the body is just `...`).
# Takes an nvrtcResult code and returns a pointer to C char.
# NOTE(review): per the NVRTC API this should be the message text for the code — confirm.
@dll.bind
def nvrtcGetErrorString(result:nvrtcResult) -> c.POINTER[Annotated[bytes, ctypes.c_char]]: ...
# Signature stub for `nvrtcVersion`, wrapped by `dll.bind` (the body is just `...`).
# `major` and `minor` are pointers to 32-bit signed ints; the return type is nvrtcResult.
# NOTE(review): presumably both are out-parameters receiving the NVRTC version — confirm.
@dll.bind
def nvrtcVersion(major:c.POINTER[Annotated[int, ctypes.c_int32]], minor:c.POINTER[Annotated[int, ctypes.c_int32]]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetNumSupportedArchs`, wrapped by `dll.bind` (the body is just `...`).
# `numArchs` is a pointer to a 32-bit signed int; the return type is nvrtcResult.
# NOTE(review): presumably an out-parameter receiving the number of supported
# architectures — confirm against the NVRTC docs.
@dll.bind
def nvrtcGetNumSupportedArchs(numArchs:c.POINTER[Annotated[int, ctypes.c_int32]]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetSupportedArchs`, wrapped by `dll.bind` (the body is just `...`).
# `supportedArchs` is a pointer to 32-bit signed ints; the return type is nvrtcResult.
# NOTE(review): presumably a caller-allocated array sized via nvrtcGetNumSupportedArchs
# that the library fills in — confirm against the NVRTC docs.
@dll.bind
def nvrtcGetSupportedArchs(supportedArchs:c.POINTER[Annotated[int, ctypes.c_int32]]) -> nvrtcResult: ...
# Record type for an NVRTC program; no fields are declared here, so it is only usable
# through pointers. `nvrtcProgram` is the pointer-to-struct alias that the program
# bindings in this module use as their handle type.
class struct__nvrtcProgram(ctypes.Structure): pass
nvrtcProgram: TypeAlias = c.POINTER[struct__nvrtcProgram]
# Signature stub for `nvrtcCreateProgram`, wrapped by `dll.bind` (the body is just `...`).
#   prog         - pointer to an nvrtcProgram handle
#   src, name    - pointers to C char
#   numHeaders   - 32-bit signed int
#   headers, includeNames - pointers to pointers to C char
# Returns nvrtcResult.
# NOTE(review): presumably `prog` receives the new handle, and `headers`/`includeNames`
# are arrays of `numHeaders` C strings — confirm against the NVRTC docs.
@dll.bind
def nvrtcCreateProgram(prog:c.POINTER[nvrtcProgram], src:c.POINTER[Annotated[bytes, ctypes.c_char]], name:c.POINTER[Annotated[bytes, ctypes.c_char]], numHeaders:Annotated[int, ctypes.c_int32], headers:c.POINTER[c.POINTER[Annotated[bytes, ctypes.c_char]]], includeNames:c.POINTER[c.POINTER[Annotated[bytes, ctypes.c_char]]]) -> nvrtcResult: ...
# Signature stub for `nvrtcDestroyProgram`, wrapped by `dll.bind` (the body is just `...`).
# Note that `prog` is a pointer to the nvrtcProgram handle, not the handle itself.
# Returns nvrtcResult.
@dll.bind
def nvrtcDestroyProgram(prog:c.POINTER[nvrtcProgram]) -> nvrtcResult: ...
# Signature stub for `nvrtcCompileProgram`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle, a 32-bit signed option count and a pointer to pointers to
# C char; returns nvrtcResult.
# NOTE(review): presumably `options` is an array of `numOptions` C strings — confirm.
@dll.bind
def nvrtcCompileProgram(prog:nvrtcProgram, numOptions:Annotated[int, ctypes.c_int32], options:c.POINTER[c.POINTER[Annotated[bytes, ctypes.c_char]]]) -> nvrtcResult: ...
# Alias used for the size out-parameters of the bindings in this module.
# NOTE(review): fixed to a 64-bit unsigned integer (ctypes.c_uint64) rather than
# ctypes.c_size_t, so it matches C `size_t` only where that type is 64 bits wide.
size_t: TypeAlias = Annotated[int, ctypes.c_uint64]
# Signature stub for `nvrtcGetPTXSize`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to size_t; returns nvrtcResult.
# NOTE(review): presumably `ptxSizeRet` receives the byte size of the generated PTX — confirm.
@dll.bind
def nvrtcGetPTXSize(prog:nvrtcProgram, ptxSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetPTX`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to C char; returns nvrtcResult.
# NOTE(review): presumably `ptx` is a caller-allocated buffer (sized via
# nvrtcGetPTXSize) that the library fills — confirm against the NVRTC docs.
@dll.bind
def nvrtcGetPTX(prog:nvrtcProgram, ptx:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetCUBINSize`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to size_t; returns nvrtcResult.
# NOTE(review): presumably `cubinSizeRet` receives the byte size of the cubin — confirm.
@dll.bind
def nvrtcGetCUBINSize(prog:nvrtcProgram, cubinSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetCUBIN`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to C char; returns nvrtcResult.
# NOTE(review): presumably `cubin` is a caller-allocated buffer (sized via
# nvrtcGetCUBINSize) that the library fills — confirm against the NVRTC docs.
@dll.bind
def nvrtcGetCUBIN(prog:nvrtcProgram, cubin:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetNVVMSize`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to size_t; returns nvrtcResult.
# NOTE(review): the NVRTC docs appear to mark the NVVM getters as deprecated in favor
# of the LTO IR ones — confirm before relying on this entry point.
@dll.bind
def nvrtcGetNVVMSize(prog:nvrtcProgram, nvvmSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetNVVM`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to C char; returns nvrtcResult.
# NOTE(review): the NVRTC docs appear to mark the NVVM getters as deprecated in favor
# of the LTO IR ones — confirm before relying on this entry point.
@dll.bind
def nvrtcGetNVVM(prog:nvrtcProgram, nvvm:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetLTOIRSize`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to size_t; returns nvrtcResult.
# NOTE(review): presumably `LTOIRSizeRet` receives the byte size of the LTO IR — confirm.
@dll.bind
def nvrtcGetLTOIRSize(prog:nvrtcProgram, LTOIRSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetLTOIR`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to C char; returns nvrtcResult.
# NOTE(review): presumably `LTOIR` is a caller-allocated buffer (sized via
# nvrtcGetLTOIRSize) that the library fills — confirm against the NVRTC docs.
@dll.bind
def nvrtcGetLTOIR(prog:nvrtcProgram, LTOIR:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetOptiXIRSize`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to size_t; returns nvrtcResult.
# NOTE(review): presumably `optixirSizeRet` receives the byte size of the OptiX IR — confirm.
@dll.bind
def nvrtcGetOptiXIRSize(prog:nvrtcProgram, optixirSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetOptiXIR`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to C char; returns nvrtcResult.
# NOTE(review): presumably `optixir` is a caller-allocated buffer (sized via
# nvrtcGetOptiXIRSize) that the library fills — confirm against the NVRTC docs.
@dll.bind
def nvrtcGetOptiXIR(prog:nvrtcProgram, optixir:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetProgramLogSize`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to size_t; returns nvrtcResult.
# NOTE(review): presumably `logSizeRet` receives the byte size of the compile log — confirm.
@dll.bind
def nvrtcGetProgramLogSize(prog:nvrtcProgram, logSizeRet:c.POINTER[size_t]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetProgramLog`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to C char; returns nvrtcResult.
# NOTE(review): presumably `log` is a caller-allocated buffer (sized via
# nvrtcGetProgramLogSize) that the library fills — confirm against the NVRTC docs.
@dll.bind
def nvrtcGetProgramLog(prog:nvrtcProgram, log:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
# Signature stub for `nvrtcAddNameExpression`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle and a pointer to C char; returns nvrtcResult.
# NOTE(review): the error names NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION and
# NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION suggest this must be called before
# compiling, with nvrtcGetLoweredName called after — confirm against the NVRTC docs.
@dll.bind
def nvrtcAddNameExpression(prog:nvrtcProgram, name_expression:c.POINTER[Annotated[bytes, ctypes.c_char]]) -> nvrtcResult: ...
# Signature stub for `nvrtcGetLoweredName`, wrapped by `dll.bind` (the body is just `...`).
# Takes the program handle, a pointer to C char (`name_expression`) and a pointer to a
# pointer to C char (`lowered_name`); returns nvrtcResult.
# NOTE(review): presumably `lowered_name` is an out-parameter set to a library-owned
# string holding the lowered (mangled) name — confirm ownership/lifetime in the NVRTC docs.
@dll.bind
def nvrtcGetLoweredName(prog:nvrtcProgram, name_expression:c.POINTER[Annotated[bytes, ctypes.c_char]], lowered_name:c.POINTER[c.POINTER[Annotated[bytes, ctypes.c_char]]]) -> nvrtcResult: ...
# Runs once at import time, after every declaration above.
# NOTE(review): presumably finalizes the record/struct types declared in this module —
# confirm in tinygrad.runtime.support.c.
c.init_records()
def __DEPRECATED__(msg):  # type: ignore
  """Evaluate `__attribute__((deprecated(msg)))` and return the result.

  Neither `__attribute__` nor `deprecated` is defined in the visible part of this
  module, so a call raises NameError unless those names are provided elsewhere.
  NOTE(review): this looks like a carried-over C macro rather than usable Python — confirm.
  """
  return __attribute__((deprecated(msg)))