/**
 * @file
 * @brief General utilities for ThunderKittens.
 */

#pragma once

// NOTE(review): the five #include directives below have no header name in this
// copy of the file (the targets appear to have been lost) — restore them before
// building.
#include
#include
#include
#include
#include

#include "base_types.cuh"

// Fallback definition for toolchains that do not already provide __forceinline__.
#ifndef __forceinline__
#define __forceinline__ __attribute__((always_inline))
#endif

/**
 * @namespace kittens
 *
 * @brief The main namespace of ThunderKittens.
 */
namespace kittens {

/* ----------  GENERAL CONSTANTS FOR KITTENS  ---------- */

/**
 * @brief Constant representing number of threads in a warp.
 */
constexpr int WARP_THREADS{64};

/**
 * @brief Get the warp ID of the current thread.
 *
 * Computed as threadIdx.x >> 6, i.e. threadIdx.x divided by 64 (the value of
 * WARP_THREADS). Only the x component of threadIdx is consulted.
 *
 * @return The warp ID.
 */
__device__ __forceinline__ int warpid() { return threadIdx.x >> 6; }

/**
 * @brief Get the number of warps in the threadblock.
 *
 * Integer division of blockDim.x by WARP_THREADS; a trailing partial warp is
 * not counted. Only the x component of blockDim is consulted.
 *
 * @return The number of warps in the threadblock.
 */
__device__ __forceinline__ int num_warps() { return blockDim.x / WARP_THREADS; }

/**
 * @brief Get the lane ID of the current thread within its warp.
 *
 * Computed as threadIdx.x & 0x3f, i.e. the low six bits (threadIdx.x modulo 64).
 *
 * @return The lane ID.
 */
__device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }

// Vector of four 32-bit signed integers (compiler ext_vector_type extension).
using i32x4 = int32_t __attribute__((ext_vector_type(4)));

// Plain aggregate of a 64-bit pointer value, a 32-bit range and a 32-bit config
// word (16 bytes of fields in total).
// NOTE(review): presumably mirrors a hardware buffer-resource descriptor — confirm
// against the code that consumes it.
struct buffer_resource {
    uint64_t ptr;
    uint32_t range;
    uint32_t config;
};

/**
 * @brief Compute the ceiling division of two integers.
 *
 * Evaluates (a + b - 1) / b, which equals ceil(a / b) for a >= 0 and b > 0.
 * b == 0 divides by zero, and a + b - 1 is not guarded against int overflow.
 *
 * @param a The dividend.
 * @param b The divisor.
 * @return The ceiling division result.
 */
__host__ __device__ inline int ceil_div(int a, int b) {
  return (a + b - 1) / b;
}

/**
 * @brief Transform a workgroup ID to a new workgroup ID based on the chunk size and number of XCDs.
 *
 * IDs past the last full block of num_xcds * chunk_size workgroups are
 * returned unchanged.
 *
 * @param workgroup_id The original workgroup ID.
 * @param num_workgroups The total number of workgroups.
 * @param num_xcds The number of XCDs.
 * @param chunk_size The chunk size.
 * @return The new workgroup ID.
*/ __host__ __device__ inline int chiplet_transform_chunked( int workgroup_id, int num_workgroups, int num_xcds, int chunk_size ) { // Current XCD int xcd = workgroup_id % num_xcds; // Largest full (NUM_XCDS*CHUNK_SIZE)-aligned block int block = num_xcds * chunk_size; int limit = (num_workgroups / block) * block; // If pid beyond the last full block, leave unchanged if (workgroup_id > limit) return workgroup_id; // Local PID (within round-robin assignment) int local_pid = workgroup_id / num_xcds; int chunk_idx = local_pid / chunk_size; int pos_in_chunk = local_pid % chunk_size; // New PID return chunk_idx * block + xcd * chunk_size + pos_in_chunk; } constexpr int MAX_SHARED_MEMORY = 160000; constexpr int NUM_XCDS = 8; constexpr int CUS_PER_XCD = 32; constexpr int NUM_CUS = CUS_PER_XCD * NUM_XCDS; /* ---------- CUSTOM TYPES ---------- */ typedef uint32_t uint2_t __attribute__((ext_vector_type(2))); /* ---------- TYPE HELPERS ---------- */ /** * @namespace ducks * * @brief ThunderKittens' namespace for template metaprogramming.. * * This includes primarily dummy types and concept wrappers, along * with a few additional utilities. */ namespace ducks { /** * @brief A type representing an empty default for a template. */ struct default_type {}; // This macro can't be done as a template, so it doesn't really have a location in kittens. #define typeof(A) typename std::remove_const::type>::type } /* ---------- SHUFFLE UTILS ---------- */ /** * @brief Mask constant for all active threads in a warp. */ static constexpr uint64_t MASK_ALL = 0xFFFFFFFFFFFFFFFF; /** * @brief Perform a shuffle down operation on a packed type synchronously across a warp. * @tparam T The type of the value to be shuffled. * @param mask[in] The mask of active threads. * @param f[in] The value to be shuffled. * @param delta[in] The number of positions to shuffle down. * @return The result of the shuffle operation. 
*/ template __device__ static inline T packed_shfl_down(uint64_t mask, const T &f, int delta) { if constexpr (std::is_same_v || std::is_same_v) { static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int)); union { __hip_bfloat162 bf162; unsigned int ui; } u; if constexpr (std::is_same_v) { u.bf162 = *reinterpret_cast(&f); } else { u.bf162 = __hip_bfloat162{*reinterpret_cast(&f), *reinterpret_cast(&f)}; } u.ui = __shfl_down_sync(mask, u.ui, delta, 64); if constexpr (std::is_same_v) { return *reinterpret_cast(&u.bf162.x); // Extract single bf16 from the .x component } else { return u.bf162; // Return full bf162 for bf16_2 case } } else { return __shfl_down(f, delta); } } template<> __device__ inline float2 packed_shfl_down(uint64_t mask, const float2 &f, int delta) { float2 r; r.x = __shfl_down(f.x, delta); r.y = __shfl_down(f.y, delta); return r; } /** * @brief Perform a packed shuffle operation synchronously across a warp. * @tparam T The type of the value to be shuffled. * @param mask[in] The mask of active threads. * @param f[in] The value to be shuffled. * @param src[in] The source lane from which to shuffle. * @return The result of the shuffle operation. 
*/
// NOTE(review): in this copy of the file the template parameter list of the primary
// template and the explicit template arguments of base_types::convertor and
// HIP_vector_type are missing (apparently stripped). The code tokens are left
// untouched here; restore them from the upstream source before building.
// Generic case: forwards to __shfl(f, src); `mask` is not used.
template __device__ static inline T packed_shfl(uint64_t mask, const T &f, int src) {
    return __shfl(f, src);
}
// bf16: the shuffled value is held in a float and then converted back; `mask` is not used.
template<> __device__ inline bf16 packed_shfl(uint64_t mask, const bf16 &f, int src) {
    float r = __shfl(base_types::convertor::convert(f), src);
    return base_types::convertor::convert(r);
}
// bf16_2: each component is converted and shuffled separately, then the float2 is converted back.
template<> __device__ inline bf16_2 packed_shfl(uint64_t mask, const bf16_2 &f, int src) {
    float2 r;
    r.x = __shfl(base_types::convertor::convert(f.x), src);
    r.y = __shfl(base_types::convertor::convert(f.y), src);
    return base_types::convertor::convert(r);
}
// half: same scheme as the bf16 specialization.
template<> __device__ inline half packed_shfl(uint64_t mask, const half &f, int src) {
    float r = __shfl(base_types::convertor::convert(f), src);
    return base_types::convertor::convert(r);
}
// half_2: same scheme as the bf16_2 specialization.
template<> __device__ inline half_2 packed_shfl(uint64_t mask, const half_2 &f, int src) {
    float2 r;
    r.x = __shfl(base_types::convertor::convert(f.x), src);
    r.y = __shfl(base_types::convertor::convert(f.y), src);
    return base_types::convertor::convert(r);
}
// float2: the two components are shuffled independently, no conversion involved.
template<> __device__ inline float2 packed_shfl(uint64_t mask, const float2 &f, int src) {
    float2 r;
    r.x = __shfl(f.x, src);
    r.y = __shfl(f.y, src);
    return r;
}

// NOTE(review): presumably 4-, 8- and 16-byte vector types, judging by the names; the
// HIP_vector_type template arguments are not visible in this copy — confirm.
using bytes_4 = HIP_vector_type;
using bytes_8 = HIP_vector_type;
using bytes_16 = HIP_vector_type;

/* ----------  SHARED MEMORY UTILS  ---------- */

// (Disabled code kept for reference.)
// namespace ducks {
// namespace sb {
// struct identifier {};
// }
// }

// template
// struct sb {
//     using identifier = ducks::sb::identifier;
//     Args... args;
// };

// namespace ducks {
// namespace sb {
// template concept all = requires {
//     typename T::identifier;
// } && std::is_same_v;
// }
// }

// Alignment helpers: KITTENS_DEFAULT_ALIGN expands to alignas(16).
#define KITTENS_ALIGN_AS(n) alignas(n)
#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(16)

/**
 * @brief Dummy structure for alignment purposes. Needed for WGMMA and TMA calls.
 *
 * Holds a single int and is declared with KITTENS_DEFAULT_ALIGN (16-byte alignment).
 */
struct KITTENS_DEFAULT_ALIGN alignment_dummy { int dummy; };

/**
 * @brief Very simple allocator for dynamic shared memory. Advances pointer and tracks alignments.
* @tparam default_alignment The default alignment this allocator will enforce. If <=0 (default -1) it will not align.
 *
 * The allocator stores only the current pointer: it performs no bounds checking and
 * does not construct objects (allocations are reinterpret_casts of the raw pointer).
 *
 * NOTE(review): every template parameter list in this struct is missing from this copy
 * of the file (apparently stripped), including the actual default of default_alignment;
 * the "-1" above comes from the pre-existing comment only — confirm against upstream.
 */
template struct shared_allocator {
    // Next free position; advanced by every allocate() call.
    int *ptr;

    private:
        // Recursive template to generate N-dimensional array type
        template struct variadic_array;
        // Recursive case: peels off the first dimension.
        template struct variadic_array {
            using type = typename variadic_array::type[first_dim];
        };
        // Base case: no dimensions left, the type is the element type itself.
        template struct variadic_array {
            using type = A;
        };
        template using variadic_array_t = typename variadic_array::type;

        // Rounds ptr up to the next multiple of `alignment` bytes; compiled out when
        // alignment <= 0, and a no-op when ptr is already aligned.
        template __device__ inline void align_ptr() {
            if constexpr (alignment > 0) {
                uint64_t p = reinterpret_cast(ptr);
                if(p % alignment != 0) {
                    ptr = (int*)(p + (alignment-(p%alignment)));
                }
            }
        }

    public:
        /**
        * @brief Construct a new shared allocator using a pointer to extern shared memory.
        * @param[in] _ptr Pointer to the start of the extern shared memory.
        */
        __device__ shared_allocator(int *_ptr): ptr(_ptr) {}
        /**
        * @brief Allocate shared memory for a single instance or N-dimensional array of type A.
        *
        * Aligns the pointer (per the allocator's default alignment, as described in the
        * struct documentation), returns a reference to the memory at that position and
        * advances the pointer past it.
        *
        * @tparam A The type of the object to allocate.
        * @tparam dims... A list of dimensions for the N-dimensional array.
        * @return Reference to the allocated object.
        */
        template __device__ inline variadic_array_t& allocate() {
            // static_assert(sizeof(A) % default_alignment == 0, "Type is not aligned properly for array allocation");
            align_ptr();
            using at = variadic_array_t;
            at*p = reinterpret_cast(ptr);
            // ptr is an int*, so the advance is expressed in ints.
            // NOTE(review): the integer division truncates — if sizeof(at) is not a
            // multiple of sizeof(int), ptr advances by fewer bytes than the object occupies.
            ptr += sizeof(at)/sizeof(int);
            return *p;
        }
        /**
        * @brief Allocate shared memory for a single instance or N-dimensional array of type A.
        *
        * Same as the overload above, except that the alignment is supplied per call.
        *
        * @tparam alignment An alignment to enforce for this particular object.
        * @tparam A The type of the object to allocate.
        * @tparam dims... A list of dimensions for the N-dimensional array.
        * @return Reference to the allocated object.
        */
        template __device__ inline variadic_array_t& allocate() {
            // static_assert(sizeof(A) % alignment == 0, "Type is not aligned properly for array allocation");
            align_ptr();
            using at = variadic_array_t;
            at*p = reinterpret_cast(ptr);
            // NOTE(review): same truncating advance as in the overload above.
            ptr += sizeof(at)/sizeof(int);
            return *p;
        }
};

} // namespace kittens