#pragma once

#include "attention_dtypes.h"

// NOTE(review): the two system include targets were lost in this copy of the
// file; <float.h> and <type_traits> are assumed here -- confirm against the
// repository version.
#include <float.h>
#include <type_traits>

namespace cacheflow {

// Q*K^T operation.
//
// Computes the dot product of the query and key fragments held by one thread
// group and returns the full sum to every thread of that group.
//
// Template parameters:
//   THREAD_GROUP_SIZE  number of adjacent lanes that cooperate on one dot
//                      product; must be a power of two no larger than a warp.
//   Vec                packed vector type of the q/k fragments.
//   N                  number of Vec elements each thread holds.
//
// Parameters:
//   q, k  this thread's N-element slices of the query and key.
//
// Returns the dot product accumulated over all lanes of the thread group.
//
// The shuffle below names every lane of the warp as a participant
// (mask = uint32_t(-1)), so all 32 lanes of the warp must call this function
// together.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  // The XOR butterfly only visits every lane of the group when the group size
  // is a power of two, and a shuffle cannot reach outside a 32-lane warp.
  static_assert(THREAD_GROUP_SIZE >= 1 && THREAD_GROUP_SIZE <= 32 &&
                (THREAD_GROUP_SIZE & (THREAD_GROUP_SIZE - 1)) == 0,
                "THREAD_GROUP_SIZE must be a power of two in [1, 32]");

  // Accumulator vector type associated with Vec by the FloatVec trait.
  using A_vec = typename FloatVec<Vec>::Type;

  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  // The accumulator type cannot be deduced from the arguments, so it is
  // passed to mul explicitly.
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Collapse the per-thread accumulator vector to a single float.
  float qk = sum(qk_vec);

  // Finalize the reduction across lanes: butterfly exchange with XOR offsets
  // THREAD_GROUP_SIZE/2, ..., 2, 1. The loop does not execute when
  // THREAD_GROUP_SIZE == 1.
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
  }
  return qk;
}

// Static wrapper around qk_dot_ that fixes the thread-group size.
//
// T is not used by this primary template.
// NOTE(review): presumably T is the scalar element type and exists so that
// specializations can be provided per element type -- confirm the parameter
// list <T, THREAD_GROUP_SIZE> against the call sites.
template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
  // Forwards to qk_dot_<THREAD_GROUP_SIZE>; Vec and N are deduced from q and k.
  template<typename Vec, int N>
  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};

} // namespace cacheflow