Refactor attention kernels (#53)

This commit is contained in:
Woosuk Kwon 2023-05-03 13:40:13 -07:00 committed by GitHub
parent 27f1410d06
commit 436e523bf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1253 additions and 2569 deletions

View File

@ -0,0 +1,5 @@
#pragma once
#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"

View File

@ -0,0 +1,47 @@
#pragma once
#include <stdint.h>
namespace cacheflow {
// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};
// A vector type to store FP32 accumulators.
template<typename T>
struct FloatVec {};
// Template vector operations.
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);
template<typename T>
inline __device__ float sum(T v);
template<typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}
template<typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}
template<typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
} // namespace cacheflow

View File

@ -0,0 +1,451 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "attention_dtypes.cuh"
#include "attention_utils.cuh"
#include <algorithm>
#define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace cacheflow {
// Utility function for attention softmax.
template<int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
// Grid: (num_heads, num_seqs).
template<
typename scalar_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS>
__global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int seq_idx = blockIdx.y;
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
// fetch or compute 16 bytes at a time.
// For example, if the size of a thread group is 4 and the data type is half,
// then the vector size is 16 / (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(scalar_t);
float qk_max = -FLT_MAX;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int context_len = context_lens[seq_idx];
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
const bool mask = token_idx >= context_len;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
accs[i] += dot(logits_vec, v_vec);
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for logits
// is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
} // namespace cacheflow
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
cacheflow::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
query_stride);
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
int BLOCK_SIZE,
int NUM_THREADS = 128>
void single_query_cached_kv_attention_launcher(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int query_stride = query.stride(0);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(T);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
int shared_mem_size = std::max(logits_size, outputs_size);
dim3 grid(num_heads, num_seqs);
dim3 block(NUM_THREADS);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
case 32:
LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
break;
case 64:
LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
break;
case 80:
LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
break;
case 96:
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
break;
case 128:
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
break;
case 160:
LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
break;
case 192:
LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
break;
case 256:
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>( \
out, \
query, \
key_cache, \
value_cache, \
scale, \
block_tables, \
context_lens, \
max_context_len);
#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 1: \
CALL_KERNEL_LAUNCHER(T, 1); \
break; \
case 2: \
CALL_KERNEL_LAUNCHER(T, 2); \
break; \
case 4: \
CALL_KERNEL_LAUNCHER(T, 4); \
break; \
case 8: \
CALL_KERNEL_LAUNCHER(T, 8); \
break; \
case 16: \
CALL_KERNEL_LAUNCHER(T, 16); \
break; \
case 32: \
CALL_KERNEL_LAUNCHER(T, 32); \
break; \
case 64: \
CALL_KERNEL_LAUNCHER(T, 64); \
break; \
case 128: \
CALL_KERNEL_LAUNCHER(T, 128); \
break; \
case 256: \
CALL_KERNEL_LAUNCHER(T, 256); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void single_query_cached_kv_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len) {
// TODO(woosuk): Support FP32 and BF16.
if (query.dtype() == at::ScalarType::Half) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
}
#undef WARP_SIZE
#undef MAX
#undef MIN

View File

@ -0,0 +1,38 @@
#pragma once
#include "attention_dtypes.cuh"
#include <float.h>
#include <type_traits>
namespace cacheflow {
// Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template<typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
};
} // namespace cacheflow

View File

@ -0,0 +1,426 @@
#pragma once
#include "attention_generic.cuh"
#include "dtype_float32.cuh"
#include <stdint.h>
namespace cacheflow {
// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
using Type = uint4;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<uint16_t> {
using Type = float;
};
template<>
struct FloatVec<uint32_t> {
using Type = float2;
};
template<>
struct FloatVec<uint2> {
using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
using Type = Float8_;
};
// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
}
inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
inline __device__ float2 half2_to_float2(uint32_t v) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
}
inline __device__ uint16_t float_to_half(float f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
return tmp.u16[0];
}
inline __device__ uint32_t float2_to_half2(float2 f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
}
// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
inline __device__ uint2 add(uint2 a, uint2 b) {
uint2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ uint4 add(uint4 a, uint4 b) {
uint4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(uint32_t a, float2 fb) {
float2 fa = half2_to_float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(uint2 a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}
template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
return c;
}
template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
return c;
}
template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}
template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}
template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
return fc;
}
template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
return fc;
}
template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
return fc;
}
template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
return fc;
}
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
return fma(h0_h0(a), b, c);
}
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
uint2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
uint32_t s = h0_h0(a);
uint2 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
uint4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
uint32_t s = h0_h0(a);
uint4 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb + fc;
}
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
return fma(h0_h0(a), b, fc);
}
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
uint32_t s = h0_h0(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
uint32_t s = h0_h0(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
// Vector sum.
template<>
inline __device__ float sum(uint16_t v) {
return half_to_float(v);
}
template<>
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
template<>
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}
template<>
inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
return sum(c);
}
// Zero-out a vector.
inline __device__ void zero(uint16_t& dst) {
dst = uint16_t(0);
}
// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
dst = float_to_half(src);
}
inline __device__ void from_float(uint32_t& dst, float2 src) {
dst = float2_to_half2(src);
}
inline __device__ void from_float(uint2& dst, Float4_ src) {
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
}
inline __device__ void from_float(uint4& dst, Float8_ src) {
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) {
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
return tmp;
}
inline __device__ Float8_ to_float(uint4 u) {
Float8_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
tmp.z = half2_to_float2(u.z);
tmp.w = half2_to_float2(u.w);
return tmp;
}
} // namespace cacheflow

View File

@ -0,0 +1,250 @@
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
namespace cacheflow {
// Define FP32 vector data types.
struct Float4_ {
float2 x;
float2 y;
};
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
// FP32 vector types for Q, K, V.
template<>
struct Vec<float, 1> {
using Type = float;
};
template<>
struct Vec<float, 2> {
using Type = float2;
};
template<>
struct Vec<float, 4> {
using Type = float4;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<float> {
using Type = float;
};
template<>
struct FloatVec<float2> {
using Type = float2;
};
template<>
struct FloatVec<float4> {
using Type = float4;
};
// Vector addition.
inline __device__ float add(float a, float b) {
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) {
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ float4 add(float4 a, float4 b) {
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
template<>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
template<>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
c.y = a * b.y;
return c;
}
template<>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
template<>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
c.y = a * b.y;
c.z = a * b.z;
c.w = a * b.w;
return c;
}
// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ float2 fma(float a, float2 b, float2 c) {
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ float4 fma(float a, float4 b, float4 c) {
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
Float4_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
// Vector sum.
template<>
inline __device__ float sum(float v) {
return v;
}
template<>
inline __device__ float sum(float2 v) {
return v.x + v.y;
}
template<>
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}
template<>
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}
template<>
inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
// Vector dot product.
inline __device__ float dot(float a, float b) {
return a * b;
}
inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b);
return c.x + c.y;
}
inline __device__ float dot(Float4_ a, Float4_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
return acc.x + acc.y;
}
inline __device__ float dot(Float8_ a, Float8_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
acc = fma(a.z, b.z, acc);
acc = fma(a.w, b.w, acc);
return acc.x + acc.y;
}
// From float to float.
inline __device__ void from_float(float& dst, float src) {
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) {
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) {
dst = src;
}
// From float to float.
inline __device__ float to_float(float u) {
return u;
}
inline __device__ float2 to_float(float2 u) {
return u;
}
inline __device__ float4 to_float(float4 u) {
return u;
}
inline __device__ Float4_ to_float(Float4_ u) {
return u;
}
inline __device__ Float8_ to_float(Float8_ u) {
return u;
}
} // namespace cacheflow

View File

@ -1,896 +0,0 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "attention_utils.h"
#include "cuda_primitives.h"
#include "reduction_utils.h"
#include <algorithm>
#define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace cacheflow {
// Grid: (num_heads, num_seqs).
template<
typename scalar_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS>
__global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int seq_idx = blockIdx.y;
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
// fetch or compute 16 bytes at a time.
// For example, if the size of a thread group is 4 and the data type is half,
// then the vector size is 16 / (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 logits and accumulation.
float *logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(scalar_t);
float qk_max = -FLT_MAX;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int context_len = context_lens[seq_idx];
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
const bool mask = token_idx >= context_len;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename FloatVec<V_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
accs[i] += dot(logits_vec, cast_to_float(v_vec));
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for logits
// is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
convert_from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
} // namespace cacheflow
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
cacheflow::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
query_stride);
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
int BLOCK_SIZE,
int NUM_THREADS = 128>
void single_query_cached_kv_attention_launcher(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int query_stride = query.stride(0);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
int shared_mem_size = std::max(logits_size, outputs_size);
dim3 grid(num_heads, num_seqs);
dim3 block(NUM_THREADS);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
case 32:
LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
break;
case 64:
LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
break;
case 80:
LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
break;
case 96:
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
break;
case 128:
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
break;
case 160:
LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
break;
case 192:
LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
break;
case 256:
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
break;
default:
assert(false);
break;
}
}
#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>( \
out, \
query, \
key_cache, \
value_cache, \
scale, \
block_tables, \
context_lens, \
max_context_len);
void single_query_cached_kv_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len) {
// TODO(woosuk): Support BF16.
if (query.element_size() == 2) {
// Half.
if (block_size == 1) {
CALL_KERNEL_LAUNCHER(uint16_t, 1);
} else if (block_size == 2) {
CALL_KERNEL_LAUNCHER(uint16_t, 2);
} else if (block_size == 4) {
CALL_KERNEL_LAUNCHER(uint16_t, 4);
} else if (block_size == 8) {
CALL_KERNEL_LAUNCHER(uint16_t, 8);
} else if (block_size == 16) {
CALL_KERNEL_LAUNCHER(uint16_t, 16);
} else if (block_size == 32) {
CALL_KERNEL_LAUNCHER(uint16_t, 32);
} else if (block_size == 64) {
CALL_KERNEL_LAUNCHER(uint16_t, 64);
} else if (block_size == 128) {
CALL_KERNEL_LAUNCHER(uint16_t, 128);
} else if (block_size == 256) {
CALL_KERNEL_LAUNCHER(uint16_t, 256);
} else {
assert(false);
}
} else {
// Float.
assert(false);
}
}
// namespace cacheflow {
// // Grid: (num_heads, num_query_tokens).
// template<
// typename scalar_t,
// int HEAD_SIZE,
// int BLOCK_SIZE,
// int NUM_THREADS>
// __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// const int seq_start_idx,
// const int seq_len,
// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
// const float scale,
// const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
// const int context_len,
// const int max_num_blocks_per_seq,
// const int q_stride) {
// constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
// const int thread_idx = threadIdx.x;
// const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
// const int seq_idx = blockIdx.y;
// // A vector type to store a part of a key or a query.
// // The vector size is configured in such a way that the threads in a thread group
// // fetch or comput 16 bytes at a time.
// // For example, if the size of a thread group is 4 and the data type is half,
// // then the vector size is 16 / (4 * sizeof(half)) == 2.
// constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
// using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
// using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
// constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
// constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
// const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
// const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// // Load the query to registers.
// // Each thread in a thread group has a different part of the query.
// // For example, if the the thread group size is 4, then the first thread in the group
// // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// // th vectors of the query, and so on.
// // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
// const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
// Q_vec q_vecs[NUM_VECS_PER_THREAD];
// #pragma unroll
// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// }
// // Memory planning.
// extern __shared__ char shared_mem[];
// // NOTE(woosuk): We use FP32 logits and accumulation.
// float *logits = reinterpret_cast<float*>(shared_mem);
// // Workspace for reduction.
// __shared__ float red_smem[2 * NUM_WARPS];
// // x == THREAD_GROUP_SIZE * VEC_SIZE
// // Each thread group fetches x elements from the key at a time.
// constexpr int x = 16 / sizeof(scalar_t);
// float qk_max = -FLT_MAX;
// const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx);
// // Iterate over the key blocks.
// // Each warp fetches a block of keys for each iteration.
// // Each thread group in a warp fetches a key from the block, and computes
// // dot product with the query.
// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
// const int physical_block_number = block_table[block_idx];
// const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
// // Load a key to registers.
// // Each thread in a thread group has a different part of the key.
// // For example, if the the thread group size is 4, then the first thread in the group
// // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// // vectors of the key, and so on.
// K_vec k_vecs[NUM_VECS_PER_THREAD];
// #pragma unroll
// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
// const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
// + head_idx * HEAD_SIZE * BLOCK_SIZE
// + physical_block_offset * x;
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// const int offset1 = (vec_idx * VEC_SIZE) / x;
// const int offset2 = (vec_idx * VEC_SIZE) % x;
// k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// }
// // Compute dot product.
// // This includes a reduction across the threads in the same thread group.
// const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
// const bool mask = token_idx >= mask_boundary;
// if (thread_group_offset == 0) {
// // Store the partial reductions to shared memory.
// // NOTE(woosuk): It is required to zero out the masked logits.
// logits[token_idx] = mask ? 0.f : qk;
// // Update the max value.
// qk_max = mask ? qk_max : fmaxf(qk_max, qk);
// }
// }
// // Perform reduction across the threads in the same warp to get the
// // max qk value for each "warp" (not across the thread block yet).
// // The 0-th thread of each thread group already has its max qk value.
// #pragma unroll
// for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
// }
// if (lane == 0) {
// red_smem[warp_idx] = qk_max;
// }
// __syncthreads();
// // TODO(woosuk): Refactor this part.
// // Get the max qk value for the sequence.
// qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
// #pragma unroll
// for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
// }
// // Broadcast the max qk value to all threads.
// qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// // Get the sum of the exp values.
// float exp_sum = 0.f;
// for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) {
// float val = __expf(logits[i] - qk_max);
// logits[i] = val;
// exp_sum += val;
// }
// exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// // Compute softmax.
// const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
// for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
// logits[i] *= inv_sum;
// }
// __syncthreads();
// // Each thread will fetch 16 bytes from the value cache at a time.
// constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t);
// using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
// using L_vec = typename FloatVec<V_vec>::Type;
// constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
// constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
// constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
// float accs[NUM_ROWS_PER_THREAD];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// accs[i] = 0.f;
// }
// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
// const int physical_block_number = block_table[block_idx];
// const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
// L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
// const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
// + head_idx * HEAD_SIZE * BLOCK_SIZE;
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE) {
// const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
// V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
// accs[i] += dot(logits_vec, cast_to_float(v_vec));
// }
// }
// }
// // Perform reduction within each warp.
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// float acc = accs[i];
// #pragma unroll
// for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
// acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
// }
// accs[i] = acc;
// }
// // NOTE(woosuk): A barrier is required because the shared memory space for logits
// // is reused for the output.
// __syncthreads();
// // Perform reduction across warps.
// float* out_smem = reinterpret_cast<float*>(shared_mem);
// #pragma unroll
// for (int i = NUM_WARPS; i > 1; i /= 2) {
// int mid = i / 2;
// // Upper warps write to shared memory.
// if (warp_idx >= mid && warp_idx < i) {
// float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// dst[row_idx] = accs[i];
// }
// }
// }
// __syncthreads();
// // Lower warps update the output.
// if (warp_idx < mid) {
// const float* src = &out_smem[warp_idx * HEAD_SIZE];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// accs[i] += src[row_idx];
// }
// }
// }
// __syncthreads();
// }
// // Write the final output.
// if (warp_idx == 0) {
// scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// convert_from_float(*(out_ptr + row_idx), accs[i]);
// }
// }
// }
// }
// // Grid: (num_heads, num_query_tokens).
// template<
// typename scalar_t,
// int HEAD_SIZE,
// int BLOCK_SIZE,
// int NUM_THREADS>
// __global__ void multi_query_cached_kv_attention_kernel(
// const int* cu_query_lens, // [num_prompts+1]
// const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx
// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
// const float scale,
// const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
// const int* __restrict__ context_lens, // [num_prompts]
// const int max_num_blocks_per_seq,
// const int q_stride) {
// const int seq_idx = blockIdx.y;
// const int prompt_idx = seq_prompt_mapping[seq_idx];
// const int seq_start_idx = cu_query_lens[prompt_idx];
// const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx;
// const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq;
// const int context_len = context_lens[prompt_idx];
// multi_query_cached_kv_attention_kernel_unoptimized_<
// scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
// out,
// q,
// seq_start_idx,
// seq_len,
// k_cache,
// v_cache,
// scale,
// block_table,
// context_len,
// max_num_blocks_per_seq,
// q_stride);
// }
// } // namespace cacheflow
// #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
// cacheflow::multi_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
// <<<grid, block, shared_mem_size, stream>>>( \
// cu_query_lens_ptr, \
// seq_prompt_mapping_ptr, \
// out_ptr, \
// query_ptr, \
// key_cache_ptr, \
// value_cache_ptr, \
// scale, \
// block_tables_ptr, \
// context_lens_ptr, \
// max_num_blocks_per_seq, \
// query_stride);
// // TODO(woosuk): Tune NUM_THREADS.
// template<
// typename T,
// int BLOCK_SIZE,
// int NUM_THREADS = 128>
// void multi_query_cached_kv_attention_launcher(
// torch::Tensor& cu_query_lens,
// torch::Tensor& seq_prompt_mapping,
// torch::Tensor& out,
// torch::Tensor& query,
// torch::Tensor& key_cache,
// torch::Tensor& value_cache,
// float scale,
// torch::Tensor& block_tables,
// torch::Tensor& context_lens,
// int max_context_len) {
// int num_seqs = query.size(0);
// int num_heads = query.size(1);
// int head_size = query.size(2);
// int max_num_blocks_per_seq = block_tables.size(1);
// int query_stride = query.stride(0);
// int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
// int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
// T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
// T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
// T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
// T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
// int* block_tables_ptr = block_tables.data_ptr<int>();
// int* context_lens_ptr = context_lens.data_ptr<int>();
// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
// int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
// int logits_size = padded_max_context_len * sizeof(float);
// int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// int shared_mem_size = std::max(logits_size, outputs_size);
// dim3 grid(num_heads, num_seqs);
// dim3 block(NUM_THREADS);
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// switch (head_size) {
// case 32:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
// break;
// case 64:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
// break;
// case 80:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
// break;
// case 96:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
// break;
// case 128:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
// break;
// case 160:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
// break;
// case 192:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
// break;
// case 256:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
// break;
// default:
// assert(false);
// break;
// }
// }
// void multi_query_cached_kv_attention(
// torch::Tensor& cu_query_lens,
// torch::Tensor& out,
// torch::Tensor& query,
// torch::Tensor& key_cache,
// torch::Tensor& value_cache,
// float scale,
// torch::Tensor& block_tables,
// torch::Tensor& context_lens,
// int block_size,
// int max_context_len) {
// torch::Tensor query_lens = cu_query_lens.to(torch::kCPU);
// int num_queries = query_lens.size(0) - 1;
// const int* query_lens_ptr = query_lens.data_ptr<int>();
// int num_seqs = query.size(0);
// torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32));
// auto accessor = cpu_tensor.accessor<int32_t, 1>();
// for (int i = 0, query_cursor = 0; i < num_seqs; ++i) {
// if (i >= query_lens_ptr[query_cursor + 1]) {
// ++query_cursor;
// }
// accessor[i] = query_cursor;
// }
// // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA)
// // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving
// // the mapping as an input parameter. Let's do this optimization in a later PR.
// torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA);
// // TODO(woosuk): Support BF16.
// if (query.element_size() == 2) {
// // Half.
// if (block_size == 8) {
// multi_query_cached_kv_attention_launcher<uint16_t, 8>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else if (block_size == 16) {
// multi_query_cached_kv_attention_launcher<uint16_t, 16>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else if (block_size == 32) {
// multi_query_cached_kv_attention_launcher<uint16_t, 32>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else {
// assert(false);
// }
// } else if (query.element_size() == 4) {
// // Float.
// if (block_size == 8) {
// multi_query_cached_kv_attention_launcher<float, 8>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else if (block_size == 16) {
// multi_query_cached_kv_attention_launcher<float, 16>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else if (block_size == 32) {
// multi_query_cached_kv_attention_launcher<float, 32>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else {
// assert(false);
// }
// } else {
// assert(false);
// }
// }
#undef WARP_SIZE
#undef MAX
#undef MIN

View File

@ -1,165 +0,0 @@
#pragma once
#include "cuda_primitives.h"
#include <float.h>
#include <type_traits>
#define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_FP32_ACUM_FOR_OUT
namespace cacheflow {
// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};
template<>
struct Vec<float, 1> {
using Type = float;
};
template<>
struct Vec<float, 2> {
using Type = float2;
};
template<>
struct Vec<float, 4> {
using Type = float4;
};
template<>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
using Type = uint4;
};
template<typename T>
struct FloatVec {};
template<>
struct FloatVec<float> {
using Type = float;
};
template<>
struct FloatVec<float2> {
using Type = float2;
};
template<>
struct FloatVec<float4> {
using Type = float4;
};
template<>
struct FloatVec<uint16_t> {
using Type = float;
};
template<>
struct FloatVec<uint32_t> {
using Type = float2;
};
template<>
struct FloatVec<uint2> {
using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
using Type = Float8_;
};
template<int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
using K_vec_acum = typename FloatVec<K_vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct Qk_dot {
template<typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
{
return qk_dot_<THREADS_PER_KEY>(q, k);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
{
float4 c;
float zero = 0.f;
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5}, \n"
" {%6}, \n"
" {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using K_vec_acum = typename FloatVec<uint32_t>::Type;
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t qk_vec_ = float2_to_half2(qk_vec);
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
return 0.f;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Qk_dot<uint16_t, 4> {
template<int N>
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
return qk_hmma_dot_(q, k);
#else
return qk_dot_<4>(q, k);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
}
};
} // namespace cacheflow
#undef MMHA_USE_FP32_ACUM_FOR_FMA
#undef MMHA_USE_FP32_ACUM_FOR_OUT

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "reduction_utils.h"
#include "reduction_utils.cuh"
namespace cacheflow {

34
csrc/reduction_utils.cuh Normal file
View File

@ -0,0 +1,34 @@
#pragma once
namespace cacheflow {
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
// blockDim.x is not divided by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
} // namespace cacheflow

View File

@ -1,76 +0,0 @@
#pragma once
namespace cacheflow {
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
#define FINAL_MASK 0xffffffff
template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
// blockDim.x is not divided by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
} // namespace cacheflow

View File

@ -18,7 +18,7 @@ ext_modules.append(cache_extension)
# Attention kernels.
attention_extension = cpp_extension.CUDAExtension(
name='cacheflow.attention_ops',
sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'],
sources=['csrc/attention.cpp', 'csrc/attention/attention_kernels.cu'],
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(attention_extension)

View File

@ -271,78 +271,6 @@ def test_multi_query_kv_attention(
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
def test_multi_query_cached_kv_attention(
num_queries: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
query_lens = random.sample(range(1, MAX_SEQ_LEN), num_queries)
cu_query_lens = [0]
for query_len in query_lens:
cu_query_lens.append(cu_query_lens[-1] + query_len)
num_total_tokens = cu_query_lens[-1]
qkv = torch.randn(
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
query, _, _ = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.randn(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.randn(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
context_lens = [
query_len + random.randint(0, MAX_SEQ_LEN - query_len)
for query_len in query_lens
]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_queries):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
output = torch.empty(
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
attention_ops.multi_query_cached_kv_attention(
cu_query_lens,
output,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
)
ref_output = ref_multi_query_cached_kv_attention(
cu_query_lens,
query,
key_cache,
value_cache,
block_tables,
context_lens,
dtype,
)
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
@torch.inference_mode()
def test_attention(seed: int) -> None:
# NOTE(woosuk): Even when the seed is fixed, there is a chance that
@ -364,24 +292,6 @@ def test_attention(seed: int) -> None:
dtype=dtype,
)
# NOTE(siyuan): Same as above. Re-run the test if it fails. Also
# note that the test is also more likely to fail due to the much
# larger amount of tokens in the input may increase the variance.
for dtype in [torch.half, torch.float]:
for block_size in [8, 16, 32]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Testing multi_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
test_multi_query_cached_kv_attention(
num_queries=11,
num_heads=3,
head_size=head_size,
block_size=block_size,
num_blocks=1024,
dtype=dtype,
)
# NOTE(woosuk): FlashAttention does not support FP32.
for dtype in [torch.half]:
# NOTE(woosuk): FlashAttention does not support head_size > 128.