From 3250ff3d829afceea80ecf7299bef629c93a6b47 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 18 Sep 2023 14:52:16 -0700
Subject: [PATCH] Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t
 Daniel H)

---
 csrc/flash_attn/flash_api.cpp | 23 +++++++++++++++++++----
 tests/test_flash_attn.py      |  1 +
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 140ba02..8b4df5b 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -235,13 +235,13 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
 }
 
 std::vector<at::Tensor>
-mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
+mha_fwd(at::Tensor &q,               // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
-        const bool is_causal,
+        bool is_causal,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
 
@@ -271,8 +271,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     const auto sizes = q.sizes();
 
     const int batch_size = sizes[0];
-    const int seqlen_q = sizes[1];
-    const int num_heads = sizes[2];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
     const int head_size_og = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
@@ -280,6 +280,15 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+
+    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
+    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 and p_dropout == 0.f and head_size_og % 8 == 0;
+    if (seqlenq_nheads_swapped) {
+        q = q.transpose(1, 2);
+        std::swap(seqlen_q, num_heads);
+    }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
@@ -388,6 +397,12 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
         if (out_.has_value()) { out_.value().copy_(out); }
     }
 
+    if (seqlenq_nheads_swapped) {
+        out = out.transpose(1, 2);
+        out_padded = out_padded.transpose(1, 2);
+        q_padded = q_padded.transpose(1, 2);
+        softmax_lse = softmax_lse.transpose(1, 2);
+    }
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index e4c9843..1765185 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -908,6 +908,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
+        (1, 147),
         (113, 203),
         (128, 217),
         (113, 211),
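
Why the swap is safe: with MQA decoding (seqlen_q == 1, num_heads_k == 1), every query head attends to the same single K/V head, so viewing the num_heads query heads as num_heads query positions of a single head computes exactly the same attention, just with dims 1 and 2 transposed; the kernel then works on a block of num_heads query rows instead of many single-row tiles, which is the motivation stated in the subject. The patch also forces is_causal = false first: a causal mask keeps all keys when there is only one query row, and it must not be applied across the head axis once the swap has been done. The PyTorch sketch below is an editor-added check of this equivalence in eager mode, not the CUDA path this patch touches; the sizes b/h/s_k/d and the helper name ref_attn are illustrative.

    # Minimal sketch: for MQA with seqlen_q == 1, transposing q from (b, 1, h, d)
    # to (b, h, 1, d) -- i.e. swapping seqlen_q and num_heads -- gives the same
    # output up to the transpose the patch applies to `out` before returning.
    import math
    import torch

    b, h, s_k, d = 2, 16, 147, 64
    q = torch.randn(b, 1, h, d)    # (batch, seqlen_q=1, num_heads=h,   head_dim)
    k = torch.randn(b, s_k, 1, d)  # (batch, seqlen_k,   num_heads_k=1, head_dim)
    v = torch.randn(b, s_k, 1, d)

    def ref_attn(q, k, v):
        # Reference attention in the (batch, seqlen, nheads, head_dim) layout used
        # by mha_fwd; the single K/V head is broadcast across all query heads (MQA).
        scale = 1.0 / math.sqrt(q.shape[-1])
        k = k.expand(-1, -1, q.shape[2], -1)
        v = v.expand(-1, -1, q.shape[2], -1)
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
        return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

    out_ref = ref_attn(q, k, v)                      # (b, 1, h, d)
    out_swapped = ref_attn(q.transpose(1, 2), k, v)  # q viewed as (b, h, 1, d)
    torch.testing.assert_close(out_ref, out_swapped.transpose(1, 2))

The second flash_api.cpp hunk undoes the transpose on out, out_padded, q_padded, and softmax_lse, so callers still receive tensors in the original (batch, seqlen_q, num_heads, ...) layout.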