diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index ba429cb6f4..dee2ba2f7c 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -298,7 +298,7 @@ def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv)) def to_char_p_p(options: list[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) @functools.cache -def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]): +def init_c_struct_t(fields: tuple[tuple[str, type[ctypes._SimpleCData]], ...]): class CStruct(ctypes.Structure): _pack_, _fields_ = 1, fields return CStruct diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 8a41570fcd..2dd030ae46 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -47,7 +47,8 @@ def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment] def to_ns_str(s: str): return msg("stringWithUTF8String:", objc_instance)(libobjc.objc_getClass(b"NSString"), s.encode()) def from_ns_str(s): return bytes(msg("UTF8String", ctypes.c_char_p)(s)).decode() -def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t) +def to_struct(*t: int, _type: type[ctypes._SimpleCData] = ctypes.c_ulong): + return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t) def wait_check(cbuf: Any): msg("waitUntilCompleted")(cbuf)