diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index ac753af..3c936ff 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -43,7 +43,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
-                      bool seqlenq_ngroups_swapped=false) {
+                      bool seqlenq_ngroups_swapped=false,
+                      const bool unpadded_lse=false) {
 
     // Reset the parameters
     params = {};
@@ -135,6 +136,9 @@ void set_params_fprop(Flash_fwd_params &params,
     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
     TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
     #endif
+
+    params.unpadded_lse = unpadded_lse;
+    params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
 }
 
 void set_params_dgrad(Flash_bwd_params &params,
@@ -168,7 +172,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
-                      bool deterministic) {
+                      bool deterministic,
+                      const bool unpadded_lse) {
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded,
                      h, h_k, d, d_rounded,
@@ -181,7 +186,9 @@ void set_params_dgrad(Flash_bwd_params &params,
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     false, // seqlenq_ngroups_swapped
+                     unpadded_lse);
 
     // Set the pointers and strides.
     params.do_ptr = dout.data_ptr();
@@ -651,8 +658,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
-
-    auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
     at::Tensor p;
     // Only return softmax if there's dropout to reduce compilation time
     if (return_softmax) {
@@ -683,7 +689,9 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     seqlenq_ngroups_swapped);
+                     seqlenq_ngroups_swapped,
+                     /*unpadded_lse*/true);
+    params.total_q = total_q;
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
@@ -739,7 +747,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
         out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
-        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
+        softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
     }
 
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
@@ -933,7 +941,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     deterministic);
+                     deterministic,
+                     /*unpadded_lse*/false);
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
@@ -986,7 +995,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
               const at::Tensor &k,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,   // total_q x num_heads x head_size
-              const at::Tensor &softmax_lse,   // b x h x s softmax logsumexp
+              const at::Tensor &softmax_lse,   // h x total_q, softmax logsumexp
               c10::optional<at::Tensor> &dq_,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -1126,7 +1135,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
-    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
+    auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
     at::Tensor dq_accum;
     if (loop) {
         // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
@@ -1137,6 +1146,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
         // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
         // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
+        // Same holds for softmax_d, since LSE is stored in unpadded format.
         if (!deterministic) {
             dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
         } else {
@@ -1182,8 +1192,10 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     deterministic);
+                     deterministic,
+                     /*unpadded_lse*/true);
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
+    params.total_q = total_q;
 
     auto launch = &run_mha_bwd;
 
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index 88a7195..49384b6 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@@ -138,6 +138,9 @@ struct Flash_fwd_params : public Qkv_params {
 
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
+
+    bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 7d35209..96aed04 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -120,10 +120,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
         // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
         + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
-    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
-        + (m_block_max - 1) * kBlockM;
-    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
-        + (m_block_max - 1) * kBlockM;
+    const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
+    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
+    const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
index aa06415..c8e3074 100644
--- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
@@ -79,7 +79,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
         + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
-    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
+    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
+    const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;
 
     Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                              Shape<Int<kBlockM>, Int<kHeadDim>>{},
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index fd68cec..b448f1f 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -24,6 +24,27 @@ using namespace cute;
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+    // Otherwise, it's written as (h, b, seqlen_q).
+    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + lse_offset);
+
+    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+        );
+
+    auto lse_layout = make_layout(lse_shape, lse_stride);
+    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+    return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
+
+
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
@@ -74,10 +95,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                                 make_stride(params.o_row_stride, params.o_head_stride, _1{}));
         Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-        Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)),
-                                  make_shape(params.b, params.h, params.seqlen_q),
-                                  make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-        Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+
+        Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
 
         typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
         auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -424,10 +443,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
 
     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -986,7 +1002,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
         const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
                                             + m_block * kBlockM) * params.d_rounded;
-        const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+        const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
+            ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
+            ) + m_block * kBlockM;
         Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ?
                                          row_offset_oaccum : row_offset_o)),
                                      Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -1092,21 +1110,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     const int tidx = threadIdx.x;
     const int bidx = blockIdx.x;
 
+    const index_t lse_size = params.b * params.h * params.seqlen_q;
+
     const index_t row_offset_lse = bidx * kBlockM;
     Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                    Shape<Int<kMaxSplits>, Int<kBlockM>>{},
-                                   make_stride(params.b * params.h * params.seqlen_q, _1{}));
+                                   make_stride(lse_size, _1{}));
+
+    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
+    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
+    Layout flat_layout = make_layout(lse_size);
+    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
+
     constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
 
-    // Read the LSE values from gmem and store them in shared memory, then tranpose them.
+    // Read the LSE values from gmem and store them in shared memory, then transpose them.
     constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
-        ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
         // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }
@@ -1145,7 +1178,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
     ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
     // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
-    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
+    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+        if (params.unpadded_lse) {
+            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+            if (lse_offset < lse_size) {
+                gLSE_unpadded(lse_offset) = lse_logsum;
+            }
+        } else {
+            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+        }
+    }
     // Store the scales exp(lse - lse_logsum) in shared memory.
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index a7f15be..9061372 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -130,7 +130,12 @@ def _flash_attn_backward(
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.bwd(
         dout,
         q,
         k,
@@ -178,7 +183,12 @@ def _flash_attn_varlen_backward(
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.varlen_bwd(
         dout,
         q,
         k,
@@ -883,7 +893,7 @@ def flash_attn_varlen_qkvpacked_func(
            (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -968,7 +978,7 @@ def flash_attn_varlen_kvpacked_func(
            (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1056,7 +1066,7 @@ def flash_attn_varlen_func(
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 308e30b..9468452 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -1151,6 +1151,7 @@ def test_flash_attn_varlen_output(
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@@ -1918,9 +1919,11 @@ def test_flash_attn_kvcache(
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
-        if new_kv
-        else (seqlen_k + 1),
+        (
+            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
+            if new_kv
+            else (seqlen_k + 1)
+        ),
         (batch_size,),
         dtype=torch.int32,
         device=device,
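
Editorial note, not part of the patch above: the diff changes the varlen forward's softmax_lse from a padded (batch_size, nheads, max_seqlen_q) tensor to an unpadded (nheads, total_q) tensor indexed via cu_seqlens_q. The sketch below shows one way a caller could map the new layout back to the old padded one; the helper name pad_lse_to_batch and the -inf fill value are our own choices, not something the patch provides.

import torch

def pad_lse_to_batch(lse_unpadded: torch.Tensor, cu_seqlens_q: torch.Tensor, max_seqlen_q: int) -> torch.Tensor:
    """Convert LSE from the unpadded varlen layout (nheads, total_q) to the old
    padded layout (batch_size, nheads, max_seqlen_q), filling unused slots with -inf."""
    nheads, total_q = lse_unpadded.shape
    batch_size = cu_seqlens_q.numel() - 1
    lse_padded = torch.full(
        (batch_size, nheads, max_seqlen_q),
        float("-inf"),
        dtype=lse_unpadded.dtype,
        device=lse_unpadded.device,
    )
    for b in range(batch_size):
        # Rows of sequence b occupy columns [cu_seqlens_q[b], cu_seqlens_q[b + 1]) of the unpadded LSE.
        start, end = int(cu_seqlens_q[b]), int(cu_seqlens_q[b + 1])
        lse_padded[b, :, : end - start] = lse_unpadded[:, start:end]
    return lse_padded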