Support unpadded LSE layout (#970)
* Support unpadded LSE layout.
  Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
  Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
* Cleanup
* Fix unpadded LSE on split-kv path
* Fix formatting and comments
* Fix inline vs forceinline

---------

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
parent 320fb59487 · commit f816dee63c
@@ -43,7 +43,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
-                      bool seqlenq_ngroups_swapped=false) {
+                      bool seqlenq_ngroups_swapped=false,
+                      const bool unpadded_lse=false) {

     // Reset the parameters
     params = {};
@@ -135,6 +136,9 @@ void set_params_fprop(Flash_fwd_params &params,
 #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
     TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
 #endif
+
+    params.unpadded_lse = unpadded_lse;
+    params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
 }

 void set_params_dgrad(Flash_bwd_params &params,
@@ -168,7 +172,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
-                      bool deterministic) {
+                      bool deterministic,
+                      const bool unpadded_lse) {

     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -181,7 +186,9 @@ void set_params_dgrad(Flash_bwd_params &params,
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     false, // seqlenq_ngroups_swapped
+                     unpadded_lse);

     // Set the pointers and strides.
     params.do_ptr = dout.data_ptr();
@@ -651,8 +658,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     auto opts = q.options();
-
-    auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
     at::Tensor p;
     // Only return softmax if there's dropout to reduce compilation time
     if (return_softmax) {
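With the varlen forward now allocating softmax_lse in the unpadded (num_heads, total_q) layout, a caller on the regular varlen path (no seqlen_q <-> ngroups swap) can recover one sequence's LSE by slicing with cu_seqlens_q. The helper below is only an illustrative sketch, not part of this commit, and its name is made up:

#include <ATen/ATen.h>

// Illustrative sketch: extract the LSE rows of sequence i from the unpadded
// (num_heads, total_q) softmax_lse returned by the varlen forward path.
at::Tensor lse_for_sequence(const at::Tensor &softmax_lse,   // (num_heads, total_q), float32
                            const at::Tensor &cu_seqlens_q,  // (batch_size + 1,), int32
                            int64_t i) {
    const int64_t start = cu_seqlens_q[i].item<int64_t>();
    const int64_t end   = cu_seqlens_q[i + 1].item<int64_t>();
    return softmax_lse.slice(/*dim=*/1, start, end);         // (num_heads, seqlen_i)
}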
@@ -683,7 +689,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     seqlenq_ngroups_swapped);
+                     seqlenq_ngroups_swapped,
+                     /*unpadded_lse*/true);
+    params.total_q = total_q;

     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
@@ -739,7 +747,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
         out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
-        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
+        softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
     }

     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
@@ -933,7 +941,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     deterministic);
+                     deterministic,
+                     /*unpadded_lse*/false);
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

     auto launch = &run_mha_bwd;
@@ -986,7 +995,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out, // total_q x num_heads x head_size
-              const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
+              const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
               c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -1126,7 +1135,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     auto opts = q.options();
-    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
+    auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
     at::Tensor dq_accum;
     if (loop) {
         // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
@@ -1137,6 +1146,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
         // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
         // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
+        // Same holds for softmax_d, since LSE is stored in unpadded format.
         if (!deterministic) {
             dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
         } else {
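To make the padding comment above concrete: dq_accum and softmax_d are allocated with total_q + 128 * batch_size rows, and sequence i starts at cu_seqlens_q[i] + 128 * i (the backward kernels further down add 128 * bidb on top of the usual cu_seqlens offset). A minimal sketch of the arithmetic, not part of the commit and with a hypothetical helper name:

#include <cstdint>
#include <vector>

// Hypothetical helper: first row of sequence i inside a buffer allocated with
// total_q + 128 * batch_size rows, as dq_accum and softmax_d are above.
int64_t padded_row_start(const std::vector<int32_t> &cu_seqlens_q, int i) {
    return static_cast<int64_t>(cu_seqlens_q[i]) + 128LL * i;
}

// With len_i = cu_seqlens_q[i + 1] - cu_seqlens_q[i], the gap to the next sequence is
//   padded_row_start(i + 1) - (padded_row_start(i) + len_i)
//     = (cu_seqlens_q[i + 1] + 128 * (i + 1)) - (cu_seqlens_q[i] + 128 * i + len_i) = 128,
// so atomicAdds that overrun a sequence by up to 128 rows never touch its neighbour.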
@@ -1182,8 +1192,10 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     deterministic);
+                     deterministic,
+                     /*unpadded_lse*/true);
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
+    params.total_q = total_q;

     auto launch = &run_mha_bwd;
@@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;

     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;

     // The scaling factors for the kernel.
     float scale_softmax;
@@ -138,6 +138,9 @@ struct Flash_fwd_params : public Qkv_params {

     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
+
+    bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
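The two LSE layouts named in the unpadded_lse comment differ only in how a (batch, head, query row) triple maps to a linear offset. A small illustrative sketch, not part of the commit, with hypothetical helper names:

#include <cstdint>

// Padded layout (b, nheads, seqlen_q): every sequence occupies seqlen_q rows.
int64_t lse_index_padded(int b, int h, int m, int nheads, int seqlen_q) {
    return (static_cast<int64_t>(b) * nheads + h) * seqlen_q + m;
}

// Unpadded layout (nheads, total_seqlen_q): sequences are packed back to back,
// so sequence b starts at cu_seqlens_q[b] and no rows are wasted on padding.
int64_t lse_index_unpadded(int b, int h, int m, const int32_t *cu_seqlens_q, int total_q) {
    return static_cast<int64_t>(h) * total_q + cu_seqlens_q[b] + m;
}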
@@ -120,10 +120,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
         // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
         + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
-    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
-        + (m_block_max - 1) * kBlockM;
-    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
-        + (m_block_max - 1) * kBlockM;
+    const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
+    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
+    const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;

     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -79,7 +79,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
         + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
-    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
+    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
+    const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;

     Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                              Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -24,6 +24,27 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+    // Otherwise, it's written as (h, b, seqlen_q).
+    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+    );
+
+    auto lse_layout = make_layout(lse_shape, lse_stride);
+    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+    return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
+

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
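The seqlenq_ngroups_swapped case above uses stride (1, seqlen_q * b, b) on a (b, h, seqlen_q) shape, which matches the (num_heads * max_seqlen_q, batch_size) view that mha_varlen_fwd reshapes softmax_lse into after undoing the seqlen_q <-> ngroups swap. As an illustrative one-liner, not part of the commit:

#include <cstdint>

// Illustrative only: LSE offset for batch index i, head h, query row m when
// seqlenq_ngroups_swapped is set, i.e. an (h, seqlen_q, b) layout with batch fastest.
int64_t lse_index_swapped(int i, int h, int m, int batch_size, int seqlen_q) {
    return (static_cast<int64_t>(h) * seqlen_q + m) * batch_size + i;
}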
@ -74,10 +95,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
|
||||
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
|
||||
make_shape(params.b, params.h, params.seqlen_q),
|
||||
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
|
||||
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
|
||||
|
||||
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
||||
@ -424,10 +443,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
|
||||
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
|
||||
Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
|
||||
make_shape(params.b, params.h, params.seqlen_q),
|
||||
make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
|
||||
Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
|
||||
Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
||||
@@ -986,7 +1002,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
        + m_block * kBlockM) * params.d_rounded;
-    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
+        ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
+        ) + m_block * kBlockM;

    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
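A note on the design choice here: while Split is true, each split still writes its partial LSE to softmax_lseaccum_ptr with the original (num_splits, b, h, seqlen_q) addressing; only the non-split write (and the final combine kernel below) uses the unpadded offset bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb). The new layout is therefore confined to the final LSE output, and the split accumulation buffers are unchanged.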
@@ -1092,21 +1110,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

+   const index_t lse_size = params.b * params.h * params.seqlen_q;
+
    const index_t row_offset_lse = bidx * kBlockM;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                   Shape<Int<kMaxSplits>, Int<kBlockM>>{},
-                                  make_stride(params.b * params.h * params.seqlen_q, _1{}));
+                                  make_stride(lse_size, _1{}));
+
+   // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
+   // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+   // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
+   Layout flat_layout = make_layout(lse_size);
+   Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+   auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+   Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+   Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+   Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
+
    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

-   // Read the LSE values from gmem and store them in shared memory, then tranpose them.
+   // Read the LSE values from gmem and store them in shared memory, then transpose them.
    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
        const int col = tidx % kBlockM;
-       ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+       ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
        if (row < kMaxSplits) { sLSE[row][col] = lse; }
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
    }
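The composed cute layout above is an index remap from the flat row position that the combine kernel iterates over to the final LSE location. A plain-C++ sketch of the same mapping, illustrative only and not part of the commit:

#include <cstdint>

// Where flat index idx (0 <= idx < b * h * seqlen_q) lands in the final LSE tensor,
// mirroring final_layout = composition(remapped_layout, composition(orig_layout, flat_layout)).
int64_t remap_lse_index(int64_t idx, int seqlen_q, int h, int b, bool seqlenq_ngroups_swapped) {
    // Decompose idx with the column-major (seqlen_q, h, b) shape of orig_layout.
    const int m  = static_cast<int>(idx % seqlen_q);
    const int hh = static_cast<int>((idx / seqlen_q) % h);
    const int bb = static_cast<int>(idx / (static_cast<int64_t>(seqlen_q) * h));
    if (seqlenq_ngroups_swapped) {
        // transposed_stride = (b, seqlen_q * b, 1): an (h, seqlen_q, b) destination layout.
        return static_cast<int64_t>(m) * b + static_cast<int64_t>(hh) * seqlen_q * b + bb;
    }
    // transposed_stride = (1, seqlen_q * b, seqlen_q): an (h, b, seqlen_q) destination layout.
    return m + static_cast<int64_t>(hh) * seqlen_q * b + static_cast<int64_t>(bb) * seqlen_q;
}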
@@ -1145,7 +1178,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
        // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
        ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
-       if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
+       if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+           if (params.unpadded_lse) {
+               const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+               if (lse_offset < lse_size) {
+                   gLSE_unpadded(lse_offset) = lse_logsum;
+               }
+           } else {
+               gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+           }
+       }
        // Store the scales exp(lse - lse_logsum) in shared memory.
        #pragma unroll
        for (int l = 0; l < kNLsePerThread; ++l) {
@@ -130,7 +130,12 @@ def _flash_attn_backward(
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.bwd(
         dout,
         q,
         k,
@@ -178,7 +183,12 @@ def _flash_attn_varlen_backward(
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.varlen_bwd(
         dout,
         q,
         k,
@@ -883,7 +893,7 @@ def flash_attn_varlen_qkvpacked_func(
            (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -968,7 +978,7 @@ def flash_attn_varlen_kvpacked_func(
            (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1056,7 +1066,7 @@ def flash_attn_varlen_func(
            (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1151,6 +1151,7 @@ def test_flash_attn_varlen_output(
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@@ -1918,9 +1919,11 @@ def test_flash_attn_kvcache(
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
-        if new_kv
-        else (seqlen_k + 1),
+        (
+            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
+            if new_kv
+            else (seqlen_k + 1)
+        ),
         (batch_size,),
         dtype=torch.int32,
         device=device,