From 9eb3d099c1fb806b4f6351a7b18f2e7731bcec64 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 7 Apr 2024 20:04:39 -0700 Subject: [PATCH] Transpose out when swapping seqlen_q and num_groups --- csrc/flash_attn/flash_api.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 001acac..ac753af 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, params.num_splits = num_splits; if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout if (num_splits < 1) { - params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); + // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128); } if (params.num_splits > 1) { at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); @@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { - const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); seqlen_q = ngroups; num_heads = num_heads_k; @@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + } if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { out = torch::empty_like(q_padded); @@ -571,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { - const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); max_seqlen_q = ngroups; num_heads = num_heads_k; @@ -627,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, total_q, num_heads, head_size_og); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); + } if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { out = torch::empty_like(q_padded);