From 2d8ea9a5303b7de8865279b66c8f7e8ed2a59aee Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 20 Sep 2023 23:38:22 -0700
Subject: [PATCH] Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza)

---
 csrc/flash_attn/flash_api.cpp          | 55 ++++++++++---------
 .../src/flash_bwd_launch_template.h    |  2 -
 csrc/flash_attn/src/flash_fwd_kernel.h | 30 +++++-----
 .../src/flash_fwd_launch_template.h    | 24 ++++----
 tests/test_flash_attn.py               |  4 +-
 5 files changed, 62 insertions(+), 53 deletions(-)

diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 8b4df5b..d62bd0d 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -282,11 +282,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
 
-    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
-    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 and p_dropout == 0.f and head_size_og % 8 == 0;
-    if (seqlenq_nheads_swapped) {
-        q = q.transpose(1, 2);
-        std::swap(seqlen_q, num_heads);
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0;
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
     }
 
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
@@ -353,9 +356,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      is_causal);
 
     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = is_sm90 || is_sm8x
-        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
-        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
@@ -369,6 +370,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
             params.oaccum_ptr = out_accum.data_ptr();
         }
+        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     }
 
     // number of times random will be generated per thread, to offset philox counter in thc random
@@ -397,11 +399,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         if (out_.has_value()) { out_.value().copy_(out); }
     }
 
-    if (seqlenq_nheads_swapped) {
-        out = out.transpose(1, 2);
-        out_padded = out_padded.transpose(1, 2);
-        q_padded = q_padded.transpose(1, 2);
-        softmax_lse = softmax_lse.transpose(1, 2);
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
@@ -1050,11 +1052,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
 
-    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
-    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
-    if (seqlenq_nheads_swapped) {
-        q = q.transpose(1, 2);
-        std::swap(seqlen_q, num_heads);
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0;
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
     }
 
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
@@ -1184,12 +1189,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.rotary_dim = 0;
     }
 
-    // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = is_sm90 || is_sm8x
-        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
-        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
-    const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
     const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
@@ -1197,6 +1199,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     if (num_splits < 1) {
         params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
     }
+    TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     if (params.num_splits > 1) {
         at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
         at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
@@ -1219,9 +1222,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         }
     }
 
-    if (seqlenq_nheads_swapped) {
-        out = out.transpose(1, 2);
-        softmax_lse = softmax_lse.transpose(1, 2);
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
     return {out, softmax_lse};
 }
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index f4f2388..fa45398 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -123,14 +123,12 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     kernel_dkv<<>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-// template
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
     if (configure) return;
     run_flash_bwd_seqk_parallel(params, stream, configure);
 }
 
-// template
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index d3736be..68d6134 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -1141,19 +1141,18 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, int Log_max_splits, bool Is_even_K, typename Params>
+template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
 inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
     constexpr int kMaxSplits = 1 << Log_max_splits;
-    constexpr int kBlockM = 16;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
+    constexpr int kNThreads = Kernel_traits::kNThreads;
     static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
-    // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer");
-    static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32");
-    static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads");
+    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
+    static_assert(kNThreads == 128, "We assume that each block has 128 threads");
 
     // Shared memory.
     // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
@@ -1169,17 +1168,17 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
                                    make_stride(params.b * params.h * params.seqlen_q, _1{}));
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse), Shape<Int<kBlockM>>{}, Stride<_1>{});
-    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads;
+    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
 
     // Read the LSE values from gmem and store them in shared memory, then tranpose them.
-    constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM;
+    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
         ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
-        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
+        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }
     // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
     __syncthreads();
@@ -1187,7 +1186,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
     // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
     // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
-    // 16 rows, so each time we load we can load 8 rows).
+    // kBlockM rows, so each time we load we can load 128 / kBlockM rows).
     // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
     // static_assert(kThreadsPerSplit <= 32);
     static_assert(kRowsPerLoadTranspose <= 32);
@@ -1230,7 +1229,13 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
-    typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+    constexpr int kBlockN = kNThreads / kBlockM;
+    using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
+    using GmemTiledCopyOaccum = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                        GmemLayoutAtomOaccum{},
+                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
     auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
     Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
     Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
@@ -1247,7 +1252,6 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
         for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
     }
     // Load Oaccum in then scale and accumulate to O
-    #pragma unroll 2
     for (int split = 0; split < params.num_splits; ++split) {
         flash::copy(
             gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
@@ -1263,11 +1267,11 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
                     tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                 }
             }
-            // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); }
+            // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }
         }
         tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
     }
-    // if (cute::thread0()) { print(tOrO); }
+    // if (cute::thread0()) { print_tensor(tOrO); }
 
     Tensor rO = flash::convert_type<Element>(tOrO);
     // Write to gO
diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index 9c8c750..51d7576 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -20,10 +20,10 @@ __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
     flash::compute_attn_splitkv(params);
 }
 
-template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
+template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
 __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
     static_assert(Log_max_splits >= 1);
-    flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
+    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }
 
 template
@@ -93,22 +93,26 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         });
     });
     if (params.num_splits > 1) {
-        dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
+        // We want kBlockM to be as small as possible for more parallelism.
+        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
+        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
+        constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             if (params.num_splits <= 2) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 4) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 8) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 16) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 32) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 64) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 128) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             }
             C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 1765185..d37c5c7 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -1505,12 +1505,12 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 @pytest.mark.parametrize("rotary_interleaved", [False, True])
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
-# @pytest.mark.parametrize("rotary_fraction", [1.0])
+# @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [64])
+# @pytest.mark.parametrize("d", [128])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [