Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H)
parent 43617deab9
commit 3250ff3d82
@@ -235,13 +235,13 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
 }

 std::vector<at::Tensor>
-mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
+mha_fwd(at::Tensor &q,               // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
-        const bool is_causal,
+        bool is_causal,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {

@@ -271,8 +271,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     const auto sizes = q.sizes();

     const int batch_size = sizes[0];
-    const int seqlen_q = sizes[1];
-    const int num_heads = sizes[2];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
     const int head_size_og = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
@@ -280,6 +280,15 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

+    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+
+    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
+    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 and p_dropout == 0.f and head_size_og % 8 == 0;
+    if (seqlenq_nheads_swapped) {
+        q = q.transpose(1, 2);
+        std::swap(seqlen_q, num_heads);
+    }
+
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
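Why the swap is legal: with a single query position and a single KV head, viewing the h query heads as h query positions of one head computes exactly the same attention, only with the two middle axes exchanged; and with seqlen_q == 1 the causal mask never masks anything. A quick sanity check of both facts with a plain PyTorch reference (a sketch for illustration only, not part of the commit; ref_attention and the sizes below are made up):

import torch

def ref_attention(q, k, v, causal=False):
    # q: (b, s_q, h, d); k, v: (b, s_k, h_k, d); KV heads are broadcast over query heads
    b, s_q, h, d = q.shape
    h_k = k.shape[2]
    q = q.transpose(1, 2)                                     # (b, h, s_q, d)
    k = k.transpose(1, 2).repeat_interleave(h // h_k, dim=1)  # (b, h, s_k, d)
    v = v.transpose(1, 2).repeat_interleave(h // h_k, dim=1)
    scores = q @ k.transpose(-2, -1) / d ** 0.5               # (b, h, s_q, s_k)
    if causal:
        s_k = k.shape[2]
        mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        scores = scores.masked_fill(~mask, float("-inf"))
    return (scores.softmax(dim=-1) @ v).transpose(1, 2)       # (b, s_q, h, d)

b, h, s_k, d = 2, 16, 147, 64
q = torch.randn(b, 1, h, d)        # decode step: seqlen_q == 1
k = torch.randn(b, s_k, 1, d)      # MQA: num_heads_k == 1
v = torch.randn(b, s_k, 1, d)

# (b, 1, h, d) against one KV head == (b, h, 1, d) against one KV head, transposed back
out_direct = ref_attention(q, k, v)
out_swapped = ref_attention(q.transpose(1, 2), k, v).transpose(1, 2)
assert torch.allclose(out_direct, out_swapped, atol=1e-6)

# and with a single query position, causal masking changes nothing
assert torch.allclose(out_direct, ref_attention(q, k, v, causal=True), atol=1e-6)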
@@ -388,6 +397,12 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
         if (out_.has_value()) { out_.value().copy_(out); }
     }

+    if (seqlenq_nheads_swapped) {
+        out = out.transpose(1, 2);
+        out_padded = out_padded.transpose(1, 2);
+        q_padded = q_padded.transpose(1, 2);
+        softmax_lse = softmax_lse.transpose(1, 2);
+    }
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }

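Undoing the swap on the outputs restores the shapes the caller expects; a rough shape sketch (the (batch, num_heads, seqlen_q) layout assumed here for softmax_lse comes from the surrounding flash_attn code, not from this diff):

import torch

b, h, d = 2, 16, 64   # illustrative sizes; seqlen_q == 1, num_heads_k == 1

# shapes as the kernel sees them after the swap: seqlen_q = h, num_heads = 1
out_internal = torch.empty(b, h, 1, d)   # (b, seqlen_q, num_heads, d)
lse_internal = torch.empty(b, 1, h)      # assumed (b, num_heads, seqlen_q) layout

# transposing back gives the caller the shapes of an unswapped call
out_caller = out_internal.transpose(1, 2)   # (b, 1, h, d)
lse_caller = lse_internal.transpose(1, 2)   # (b, h, 1)
assert out_caller.shape == (b, 1, h, d)
assert lse_caller.shape == (b, h, 1)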
@@ -908,6 +908,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
+        (1, 147),
         (113, 203),
         (128, 217),
         (113, 211),
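The new (1, 147) case covers the seqlen_q == 1 decode shape; combined with the MQA parametrization it exercises the swapped path above. A minimal decode-style call that should hit it (assuming the flash_attn_func Python wrapper, flash_attn_func(q, k, v, dropout_p=..., causal=...); check the installed package's actual signature):

import torch
from flash_attn import flash_attn_func

b, h, s_k, d = 2, 16, 147, 64
q = torch.randn(b, 1, h, d, device="cuda", dtype=torch.float16)    # single query token
k = torch.randn(b, s_k, 1, d, device="cuda", dtype=torch.float16)  # MQA: one KV head
v = torch.randn(b, s_k, 1, d, device="cuda", dtype=torch.float16)

# dropout_p == 0 and head_size % 8 == 0, so the forward pass can take the swapped path;
# causal=True is fine because is_causal is reset internally when seqlen_q == 1
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)   # (b, 1, h, d), same as without the optimization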