Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H)

Tri Dao 2023-09-18 14:52:16 -07:00
parent 43617deab9
commit 3250ff3d82
2 changed files with 20 additions and 4 deletions

@@ -235,13 +235,13 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
 }
 std::vector<at::Tensor>
-mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
+mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
-        const bool is_causal,
+        bool is_causal,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -271,8 +271,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     const auto sizes = q.sizes();
     const int batch_size = sizes[0];
-    const int seqlen_q = sizes[1];
-    const int num_heads = sizes[2];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
     const int head_size_og = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
@@ -280,6 +280,15 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+    if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
+    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
+    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 and p_dropout == 0.f and head_size_og % 8 == 0;
+    if (seqlenq_nheads_swapped) {
+        q = q.transpose(1, 2);
+        std::swap(seqlen_q, num_heads);
+    }
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
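The swap above is purely a layout change. With seqlen_q == 1 and a single K/V head (MQA), every query head attends to the same keys and values independently, so reinterpreting the head dimension as the query-length dimension gives the kernel more query rows to parallelize over without changing the math. A minimal PyTorch sketch of that equivalence (illustrative only, not the FlashAttention kernel; the shapes and the naive_attention helper are assumptions made for this example):

import torch

def naive_attention(q, k, v):
    # q: (batch, seqlen_q, num_heads, head_dim); k, v: (batch, seqlen_k, num_heads_k, head_dim).
    # Broadcasting over the head dimension handles MQA (num_heads_k == 1).
    qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))  # -> (batch, heads, seqlen, dim)
    scores = qh @ kh.transpose(-2, -1) / q.shape[-1] ** 0.5
    return (scores.softmax(dim=-1) @ vh).transpose(1, 2)

b, h, s_k, d = 2, 16, 147, 64
q = torch.randn(b, 1, h, d)      # decoding: one query token, many query heads
k = torch.randn(b, s_k, 1, d)    # MQA: a single shared K/V head
v = torch.randn(b, s_k, 1, d)

out_ref = naive_attention(q, k, v)                   # (b, 1, h, d)
out_swp = naive_attention(q.transpose(1, 2), k, v)   # (b, h, 1, d), same numbers
torch.testing.assert_close(out_ref, out_swp.transpose(1, 2))

This is why the forward path only needs to transpose q on the way in and transpose the outputs back on the way out, as the next hunk does.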
@@ -388,6 +397,12 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
         if (out_.has_value()) { out_.value().copy_(out); }
     }
+    if (seqlenq_nheads_swapped) {
+        out = out.transpose(1, 2);
+        out_padded = out_padded.transpose(1, 2);
+        q_padded = q_padded.transpose(1, 2);
+        softmax_lse = softmax_lse.transpose(1, 2);
+    }
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
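For reference, a decode-style call that should take the swapped path from the Python side would look roughly like the sketch below. This is hedged: it assumes a CUDA device, fp16 inputs, and the flash_attn package's flash_attn_func entry point, with shapes chosen to satisfy the seqlen_q == 1, num_heads_k == 1, dropout_p == 0, and head_size % 8 == 0 conditions above.

import torch
from flash_attn import flash_attn_func

b, h, s_k, d = 2, 16, 147, 64
q = torch.randn(b, 1, h, d, device="cuda", dtype=torch.float16)    # single query token
k = torch.randn(b, s_k, 1, d, device="cuda", dtype=torch.float16)  # one shared K/V head (MQA)
v = torch.randn(b, s_k, 1, d, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # causal is a no-op when seqlen_q == 1
print(out.shape)  # (b, 1, h, d): the kernel-internal transpose is undone before returning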

@@ -908,6 +908,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
+        (1, 147),
         (113, 203),
         (128, 217),
         (113, 211),