Transpose out when swapping seqlen_q and num_groups

This commit is contained in:
Tri Dao 2024-04-07 20:04:39 -07:00
parent f692b98d80
commit 9eb3d099c1

View File

@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
params.num_splits = num_splits;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
}
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
seqlen_q = ngroups;
num_heads = num_heads_k;
@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
}
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
@ -571,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
max_seqlen_q = ngroups;
num_heads = num_heads_k;
@ -627,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
}
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);