Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)
This commit is contained in:
parent
07005806ff
commit
ee77b931b9
@ -992,15 +992,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &kcache, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &vcache, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
|
||||
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
|
||||
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
|
||||
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
|
||||
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
|
||||
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
bool is_causal,
|
||||
int num_splits
|
||||
) {
|
||||
|
||||
@ -1032,8 +1032,8 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q = sizes[1];
|
||||
const int num_heads = sizes[2];
|
||||
int seqlen_q = sizes[1];
|
||||
int num_heads = sizes[2];
|
||||
const int head_size_og = sizes[3];
|
||||
const int seqlen_k = kcache.size(1);
|
||||
const int num_heads_k = kcache.size(2);
|
||||
@ -1041,6 +1041,15 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
|
||||
|
||||
// Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
|
||||
const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
|
||||
if (seqlenq_nheads_swapped) {
|
||||
q = q.transpose(1, 2);
|
||||
std::swap(seqlen_q, num_heads);
|
||||
}
|
||||
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
@ -1111,8 +1120,9 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
|
||||
int seqlen_knew = k.size(1);
|
||||
CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
|
||||
if (head_size_og % 8 != 0) {
|
||||
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
@ -1120,6 +1130,7 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
k_padded = k;
|
||||
v_padded = v;
|
||||
}
|
||||
params.seqlen_knew = seqlen_knew;
|
||||
params.knew_ptr = k_padded.data_ptr();
|
||||
params.vnew_ptr = v_padded.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
@ -1175,6 +1186,10 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
}
|
||||
}
|
||||
|
||||
if (seqlenq_nheads_swapped) {
|
||||
out = out.transpose(1, 2);
|
||||
softmax_lse = softmax_lse.transpose(1, 2);
|
||||
}
|
||||
return {out, softmax_lse};
|
||||
}
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ struct BlockInfo {
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
|
||||
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
|
||||
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -68,7 +68,7 @@ struct Flash_fwd_params : public Qkv_params {
|
||||
void * __restrict__ softmax_lseaccum_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
|
||||
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
|
||||
@ -644,7 +644,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
|
||||
|
||||
@ -838,9 +838,9 @@ def flash_attn_with_kvcache(
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
|
||||
v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
|
||||
k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
|
||||
k_cache, starting at the indices specified by cache_seqlens.
|
||||
v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
|
||||
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
|
||||
k with k_cache, starting at the indices specified by cache_seqlens.
|
||||
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
|
||||
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
|
||||
KV cache.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user