Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)

This commit is contained in:
Tri Dao 2023-09-10 22:56:33 -07:00
parent 07005806ff
commit ee77b931b9
5 changed files with 29 additions and 14 deletions

View File

@ -992,15 +992,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
}
std::vector<at::Tensor>
mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &vcache, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
const bool is_causal,
bool is_causal,
int num_splits
) {
@ -1032,8 +1032,8 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
const int head_size_og = sizes[3];
const int seqlen_k = kcache.size(1);
const int num_heads_k = kcache.size(2);
@ -1041,6 +1041,15 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
// Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
if (seqlenq_nheads_swapped) {
q = q.transpose(1, 2);
std::swap(seqlen_q, num_heads);
}
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
@ -1111,8 +1120,9 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
int seqlen_knew = k.size(1);
CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
if (head_size_og % 8 != 0) {
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
@ -1120,6 +1130,7 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
k_padded = k;
v_padded = v;
}
params.seqlen_knew = seqlen_knew;
params.knew_ptr = k_padded.data_ptr();
params.vnew_ptr = v_padded.data_ptr();
// All stride are in elements, not bytes.
@ -1175,6 +1186,10 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
}
}
if (seqlenq_nheads_swapped) {
out = out.transpose(1, 2);
softmax_lse = softmax_lse.transpose(1, 2);
}
return {out, softmax_lse};
}

View File

@ -19,7 +19,7 @@ struct BlockInfo {
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}

View File

@ -68,7 +68,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
// The scaling factors for the kernel.
float scale_softmax;

View File

@ -644,7 +644,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;

View File

@ -838,9 +838,9 @@ def flash_attn_with_kvcache(
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
softmax_scale: float. The scaling of QK^T before applying softmax.