#pragma once

#include "util.cuh"

#include <array>
#include <concepts>
#include <cstdint>
#include <stdexcept>
#include <string>

#include <pybind11/pybind11.h>

// Helpers for exposing kernels and host functions to Python through pybind11.
// Each bound Python function takes one positional argument per globals member,
// converts it with from_object, brace-initializes the globals struct from the
// converted values, and then launches a kernel or calls a host function.
namespace kittens {
namespace py {

// Converts a Python object into a C++ value of type T.
// Default case: defer to pybind11's registered type casters.
template<typename T> struct from_object {
    static T make(pybind11::object obj) {
        return obj.cast<T>();
    }
};

// Specialization for global-layout (GL) types: wraps the storage of a
// torch.Tensor in a GL via make_gl.
// Runtime requirements (each violation throws std::runtime_error, which
// pybind11 surfaces to Python as RuntimeError):
//   - the object's class is named "Tensor"
//   - the tensor is contiguous
//   - the tensor's device type is "cuda" (PyTorch's ROCm builds also expose
//     HIP devices under the "cuda" device type)
//   - the tensor has at most 4 dimensions
// The shape is right-aligned into 4 dims and left-padded with 1s, so an
// (R, C) tensor is passed to make_gl as (1, 1, R, C).
// NOTE(review): dtype is not checked here; presumably the caller guarantees
// the tensor's element type matches GL. Confirm.
template<ducks::gl::all GL> struct from_object<GL> {
    static GL make(pybind11::object obj) {
        // Only objects whose class is literally named "Tensor" are accepted.
        const bool is_tensor =
            pybind11::hasattr(obj, "__class__") &&
            obj.attr("__class__").attr("__name__").cast<std::string>() == "Tensor";
        if (!is_tensor) {
            throw std::runtime_error("Expected a torch.Tensor");
        }

        // make_gl receives only a base pointer plus dims, so the storage must be dense.
        if (!obj.attr("is_contiguous")().cast<bool>()) {
            throw std::runtime_error("Tensor must be contiguous");
        }
        // Reject every non-GPU device (cpu, meta, mps, ...), not just "cpu":
        // any other device would hand the kernel an unusable pointer.
        if (obj.attr("device").attr("type").cast<std::string>() != "cuda") {
            throw std::runtime_error("Tensor must be on CUDA device");
        }

        // Get shape, left-padded with 1s up to 4 dims.
        std::array<int, 4> shape = {1, 1, 1, 1};
        auto py_shape = obj.attr("shape").cast<pybind11::tuple>();
        const size_t dims = py_shape.size();
        if (dims > 4) {
            throw std::runtime_error("Expected Tensor.ndim <= 4");
        }
        for (size_t i = 0; i < dims; ++i) {
            shape[4 - dims + i] = pybind11::cast<int>(py_shape[i]);
        }

        // Device address of the first element, as reported by Tensor.data_ptr().
        const uint64_t data_ptr = obj.attr("data_ptr")().cast<uint64_t>();

        return make_gl<GL>(data_ptr, shape[0], shape[1], shape[2], shape[3]);
    }
};

// Satisfied by globals types that report a dynamic shared-memory size to
// request at launch time.
template<typename T> concept has_dynamic_shared_memory = requires(T t) {
    { t.dynamic_shared_memory() } -> std::convertible_to<int>;
};

// Splits a pointer-to-member type `MT T::*` into its member type and owning class.
template<typename> struct trait;
template<typename MT, typename T> struct trait<MT T::*> { using member_type = MT; using type = T; };

// Maps any type to pybind11::object. Used to give the bound lambda exactly one
// pybind11::object parameter per member pointer.
template<typename> using object = pybind11::object;

// Throws std::runtime_error naming `what` and the HIP error string when `err`
// is not hipSuccess. This lets launch/configuration failures reach Python as an
// exception instead of being silently dropped.
static inline void throw_on_hip_error(hipError_t err, const char *what) {
    if (err != hipSuccess) {
        throw std::runtime_error(std::string(what) + " failed: " + hipGetErrorString(err));
    }
}

// Registers Python function `name` on module `m`. Calling it:
//   1. converts each positional argument with from_object<member type>;
//   2. brace-initializes a TGlobal from the converted values, in argument order;
//   3. launches `kernel` with TGlobal::grid() / TGlobal::block(). If TGlobal
//      satisfies has_dynamic_shared_memory, it first raises the kernel's
//      dynamic shared-memory limit and passes that size to the launch.
// member_ptrs are used only for their types (never dereferenced). Because the
// globals object is brace-initialized positionally, they presumably must be
// listed in TGlobal's member declaration order. Confirm for new bindings.
// Launch-configuration errors are reported as RuntimeError. Asynchronous
// execution faults are not waited for here.
// NOTE(review): the launch has no stream argument, so it goes to the default
// stream rather than PyTorch's current stream. Confirm callers rely on that.
template<auto kernel, typename TGlobal>
static void bind_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) {
    m.def(name, [](object<decltype(member_ptrs)>... args) {
        TGlobal g {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
        if constexpr (has_dynamic_shared_memory<TGlobal>) {
            const int dynamic_smem = (int)g.dynamic_shared_memory();
            throw_on_hip_error(
                hipFuncSetAttribute((const void *)kernel,
                                    hipFuncAttributeMaxDynamicSharedMemorySize,
                                    dynamic_smem),
                "hipFuncSetAttribute(MaxDynamicSharedMemorySize)");
            kernel<<<g.grid(), g.block(), dynamic_smem>>>(g);
        } else {
            kernel<<<g.grid(), g.block()>>>(g);
        }
        // Launches do not return an error code; query it explicitly.
        throw_on_hip_error(hipGetLastError(), "kernel launch");
    });
}

// Registers Python function `name` on module `m`. It builds a TGlobal exactly
// like bind_kernel does, then calls the host-side `function` with it.
template<auto function, typename TGlobal>
static void bind_function(auto m, auto name, auto TGlobal::*... member_ptrs) {
    m.def(name, [](object<decltype(member_ptrs)>... args) {
        TGlobal g {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
        function(g);
    });
}

} // namespace py
} // namespace kittens