From ee77b931b91476a41e4b1c19cf76281444044908 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 10 Sep 2023 22:56:33 -0700
Subject: [PATCH] Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)

---
 csrc/flash_attn/flash_api.cpp          | 31 +++++++++++++++++++++++--------
 csrc/flash_attn/src/block_info.h       |  2 +-
 csrc/flash_attn/src/flash.h            |  2 +-
 csrc/flash_attn/src/flash_fwd_kernel.h |  2 +-
 flash_attn/flash_attn_interface.py     |  6 +++---
 5 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 430c3a6..ff86d01 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -992,15 +992,15 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
 }
 
 std::vector<at::Tensor>
-mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
+mha_fwd_kvcache(at::Tensor &q,                       // batch_size x seqlen_q x num_heads x head_size
                 const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
                 const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
-                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
-                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
+                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
+                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_, // batch_size
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
-                const bool is_causal,
+                bool is_causal,
                 int num_splits
                 ) {
@@ -1032,8 +1032,8 @@ mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x
     const auto sizes = q.sizes();
 
     const int batch_size = sizes[0];
-    const int seqlen_q = sizes[1];
-    const int num_heads = sizes[2];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
     const int head_size_og = sizes[3];
     const int seqlen_k = kcache.size(1);
     const int num_heads_k = kcache.size(2);
@@ -1041,6 +1041,15 @@ mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+
+    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
+    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
+    if (seqlenq_nheads_swapped) {
+        q = q.transpose(1, 2);
+        std::swap(seqlen_q, num_heads);
+    }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
@@ -1111,8 +1120,9 @@ mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x
         TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
         TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
         TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
-        CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
-        CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
+        int seqlen_knew = k.size(1);
+        CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
         if (head_size_og % 8 != 0) {
             k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
             v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
@@ -1120,6 +1130,7 @@ mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x
             k_padded = k;
             v_padded = v;
         }
+        params.seqlen_knew = seqlen_knew;
         params.knew_ptr = k_padded.data_ptr();
         params.vnew_ptr = v_padded.data_ptr();
         // All stride are in elements, not bytes.
@@ -1175,6 +1186,10 @@ mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x
         }
     }
 
+    if (seqlenq_nheads_swapped) {
+        out = out.transpose(1, 2);
+        softmax_lse = softmax_lse.transpose(1, 2);
+    }
     return {out, softmax_lse};
 }
 
diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h
index 7e60280..18793e9 100644
--- a/csrc/flash_attn/src/block_info.h
+++ b/csrc/flash_attn/src/block_info.h
@@ -19,7 +19,7 @@ struct BlockInfo {
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
         , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
+        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
     {
     }
 
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index 946d0d5..ff49cb8 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -68,7 +68,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
 
     // The scaling factors for the kernel.
     float scale_softmax;
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index dcd0814..c0e3df5 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -644,7 +644,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     const BlockInfo binfo(params, bidb);
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
-    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
+    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
 
     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index ae49728..d3c1305 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -838,9 +838,9 @@ def flash_attn_with_kvcache(
         q: (batch_size, seqlen, nheads, headdim)
         k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
-        k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
-            k_cache, starting at the indices specified by cache_seqlens.
-        v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
+        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
+            k with k_cache, starting at the indices specified by cache_seqlens.
+        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
         softmax_scale: float. The scaling of QK^T before applying softmax.
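
Note (not part of the patch): the swap above is legal because, with seqlen_q == 1 and a single KV head (MQA), every query head attends to the same K and V, so the nheads query heads can be viewed as nheads query positions of a single head. Transposing q from (b, 1, h, d) to (b, h, 1, d) gives the kernel that faster layout, and the output only has to be transposed back before returning. Below is a minimal PyTorch sketch of the equivalence under made-up shapes; ref_attn is a hypothetical plain-softmax reference written for illustration only, not the FlashAttention kernel or any function from this repo.

import torch

def ref_attn(q, k, v):
    # Plain softmax attention (no causal mask) on (batch, seqlen, nheads, headdim) tensors.
    # With seqlen_q == 1 a causal mask is a no-op, which is why the patch sets is_causal = false there.
    nheads, nheads_k = q.shape[2], k.shape[2]
    if nheads != nheads_k:
        # Repeat the shared K/V head across query heads (MQA/GQA).
        k = k.repeat_interleave(nheads // nheads_k, dim=2)
        v = v.repeat_interleave(nheads // nheads_k, dim=2)
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q * scale, k)
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

b, h, d, seqlen_k = 2, 8, 64, 128
q = torch.randn(b, 1, h, d)               # (batch_size, seqlen_q=1, nheads=h, headdim)
k_cache = torch.randn(b, seqlen_k, 1, d)  # (batch_size, seqlen_k, nheads_k=1, headdim)
v_cache = torch.randn(b, seqlen_k, 1, d)

out = ref_attn(q, k_cache, v_cache)
# Swap seqlen_q and nheads: (b, 1, h, d) -> (b, h, 1, d), attend with a single query head,
# then transpose back, mirroring what mha_fwd_kvcache now does around the kernel call.
out_swapped = ref_attn(q.transpose(1, 2), k_cache, v_cache).transpose(1, 2)
print(torch.allclose(out, out_swapped, atol=1e-5))  # True

Because both the transpose and the transpose back happen inside mha_fwd_kvcache, callers of flash_attn_with_kvcache keep passing q as (batch_size, seqlen, nheads, headdim) and see no API change.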