Support unpadded LSE layout (#970)

* Support unpadded LSE layout.

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>

* Cleanup

* Fix unpadded LSE on split-kv path

* Fix formatting and comments

* Fix inline vs forceinline

---------

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
This commit is contained in:
Grigory Sizov 2024-06-27 11:38:13 +02:00 committed by GitHub
parent 320fb59487
commit f816dee63c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 108 additions and 38 deletions

View File

@ -43,7 +43,8 @@ void set_params_fprop(Flash_fwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
bool seqlenq_ngroups_swapped=false) {
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false) {
// Reset the parameters
params = {};
@ -135,6 +136,9 @@ void set_params_fprop(Flash_fwd_params &params,
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
#endif
params.unpadded_lse = unpadded_lse;
params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
}
void set_params_dgrad(Flash_bwd_params &params,
@ -168,7 +172,8 @@ void set_params_dgrad(Flash_bwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
bool deterministic) {
bool deterministic,
const bool unpadded_lse) {
set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@ -181,7 +186,9 @@ void set_params_dgrad(Flash_bwd_params &params,
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
false, // seqlenq_ngroups_swapped
unpadded_lse);
// Set the pointers and strides.
params.do_ptr = dout.data_ptr();
@ -651,8 +658,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
at::Tensor p;
// Only return softmax if there's dropout to reduce compilation time
if (return_softmax) {
@ -683,7 +689,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
softmax_scale,
window_size_left,
window_size_right,
seqlenq_ngroups_swapped);
seqlenq_ngroups_swapped,
/*unpadded_lse*/true);
params.total_q = total_q;
if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
@ -739,7 +747,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
}
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
@ -933,7 +941,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
softmax_scale,
window_size_left,
window_size_right,
deterministic);
deterministic,
/*unpadded_lse*/false);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
auto launch = &run_mha_bwd;
@ -986,7 +995,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@ -1126,7 +1135,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
if (loop) {
// We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
@ -1137,6 +1146,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
// be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
// allowed to do. So we won't have to do any bound checking, and performance should stay the same.
// Same holds for softmax_d, since LSE is stored in unpadded format.
if (!deterministic) {
dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
@ -1182,8 +1192,10 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_scale,
window_size_left,
window_size_right,
deterministic);
deterministic,
/*unpadded_lse*/true);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
params.total_q = total_q;
auto launch = &run_mha_bwd;

View File

@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
// The scaling factors for the kernel.
float scale_softmax;
@ -138,6 +138,9 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
};
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -120,10 +120,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
+ ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+ (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
+ (m_block_max - 1) * kBlockM;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
+ (m_block_max - 1) * kBlockM;
const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
// Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},

View File

@ -79,7 +79,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
// Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},

View File

@ -24,6 +24,27 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
// When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
// Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
// Otherwise, it's written as (h, b, seqlen_q).
const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
);
auto lse_layout = make_layout(lse_shape, lse_stride);
Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
@ -74,10 +95,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@ -424,10 +443,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
make_shape(params.b, params.h, params.seqlen_q),
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@ -986,7 +1002,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ m_block * kBlockM) * params.d_rounded;
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
) + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
@ -1092,21 +1110,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const index_t lse_size = params.b * params.h * params.seqlen_q;
const index_t row_offset_lse = bidx * kBlockM;
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
Shape<Int<kMaxSplits>, Int<kBlockM>>{},
make_stride(params.b * params.h * params.seqlen_q, _1{}));
make_stride(lse_size, _1{}));
// LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
// This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
// This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
Layout flat_layout = make_layout(lse_size);
Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
// Read the LSE values from gmem and store them in shared memory, then tranpose them.
// Read the LSE values from gmem and store them in shared memory, then transpose them.
constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
#pragma unroll
for (int l = 0; l < kNLsePerThread; ++l) {
const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
const int col = tidx % kBlockM;
ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
if (row < kMaxSplits) { sLSE[row][col] = lse; }
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
}
@ -1145,7 +1178,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
// lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
if (params.unpadded_lse) {
const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
if (lse_offset < lse_size) {
gLSE_unpadded(lse_offset) = lse_logsum;
}
} else {
gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
}
}
// Store the scales exp(lse - lse_logsum) in shared memory.
#pragma unroll
for (int l = 0; l < kNLsePerThread; ++l) {

View File

@ -130,7 +130,12 @@ def _flash_attn_backward(
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
(
dq,
dk,
dv,
softmax_d,
) = flash_attn_cuda.bwd(
dout,
q,
k,
@ -178,7 +183,12 @@ def _flash_attn_varlen_backward(
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
(
dq,
dk,
dv,
softmax_d,
) = flash_attn_cuda.varlen_bwd(
dout,
q,
k,
@ -883,7 +893,7 @@ def flash_attn_varlen_qkvpacked_func(
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@ -968,7 +978,7 @@ def flash_attn_varlen_kvpacked_func(
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@ -1056,7 +1066,7 @@ def flash_attn_varlen_func(
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).

View File

@ -1151,6 +1151,7 @@ def test_flash_attn_varlen_output(
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@ -1918,9 +1919,11 @@ def test_flash_attn_kvcache(
cache_seqlens = torch.randint(
0 if new_kv else 1,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1),
(
(seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1)
),
(batch_size,),
dtype=torch.int32,
device=device,