#pragma once

#include <stdint.h>

namespace cacheflow {

// A vector type to store Q, K, V elements.
//
// The primary template is intentionally empty: it only names the
// (element type, vector width) pair. Concrete storage presumably comes from
// specializations defined in other headers (not visible here -- confirm).
template <typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
//
// Like Vec, the primary template is empty; specializations elsewhere are
// expected to provide the float-based counterpart of T (TODO confirm).
template <typename T>
struct FloatVec {};

// Template vector operations.
//
// mul() and sum() are declared only; each supported type must provide its
// own specialization/overload, otherwise instantiation fails at link time.

// Multiplies `a` by `b` and returns the result as `Acc` (presumably lane-wise;
// the actual semantics live in the per-type definitions).
// `Acc` appears only in the return type, so it cannot be deduced and must be
// given explicitly, e.g. mul<Acc, A, B>(a, b).
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

// Reduces `v` to a single float (presumably a sum over all lanes).
template <typename T>
inline __device__ float sum(T v);

// Dot product of two values of the same type T: the product is formed in T
// via mul<T, T, T>() and then reduced to a float with sum().
template <typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

// Dot product with an explicit product/accumulator type A: the product of
// `a` and `b` is formed as A via mul<A, T, T>() and then reduced with sum().
// A is not deducible from the arguments, so callers write dot<A>(a, b).
template <typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

// Sets every byte of `dst` to zero.
//
// The zero value is built through a union of T and an array of 32-bit words,
// so any plain-data T (assumed trivially copyable / default-constructible)
// can be cleared without a per-type specialization.
//
// The word count is rounded UP. With the previous round-down
// (sizeof(T) / 4), a type smaller than 4 bytes got a zero-length array (and
// `dst` received uninitialized bytes), and any size that is not a multiple of
// 4 left its trailing bytes uninitialized. For sizes that are multiples of 4
// the word count is unchanged.
template <typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS =
      static_cast<int>((sizeof(T) + sizeof(uint32_t) - 1) / sizeof(uint32_t));
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  // All WORDS * 4 >= sizeof(T) bytes of the union are now zero, so `raw`
  // holds an all-zero bit pattern.
  dst = tmp.raw;
}

}  // namespace cacheflow