Write zero to out / grad if seqlen_q or seqlen_k is zero

This commit is contained in:
Tri Dao 2023-11-19 22:20:01 -08:00
parent 43bb6d8aaa
commit db2f80692c
3 changed files with 63 additions and 49 deletions

View File

@ -405,8 +405,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
params.philox_args = gen->philox_cuda_state(counter_offset); params.philox_args = gen->philox_cuda_state(counter_offset);
} }
if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream); run_mha_fwd(params, stream);
} else {
// If seqlen_k == 0, then we have an empty tensor: skip the kernel launch, set the output to 0, and fill softmax_lse with +infinity.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
at::Tensor out_padded = out; at::Tensor out_padded = out;
if (head_size_og % 8 != 0) { if (head_size_og % 8 != 0) {
@ -794,7 +800,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
params.rng_state[1] = std::get<1>(seeds); params.rng_state[1] = std::get<1>(seeds);
} }
if (seqlen_q > 0) {
launch(params, stream, /*configure=*/false); launch(params, stream, /*configure=*/false);
} else {
// If seqlen_q == 0, then we have an empty tensor: skip the kernel launch and set dk, dv and softmax_d to 0.
dk.zero_();
dv.zero_();
softmax_d.zero_();
}
// For MQA/GQA we need to sum dK and dV across the groups // For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) { if (num_heads_k != num_heads) {

View File

@ -444,7 +444,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb); const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return; if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
if (Is_local) { if (Is_local) {
@ -672,7 +672,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// We might need to exit early and write 0 to dK and dV for those blocks. // We might need to exit early and write 0 to dK and dV for those blocks.
// Otherwise we get wrong result for the case where we don't enter the for loop. // Otherwise we get wrong result for the case where we don't enter the for loop.
// And we might read OOB elements from gQ and gdO. // And we might read OOB elements from gQ and gdO.
if (Is_local && m_block < m_block_min) { // This also covers the case where actual_seqlen_q == 0
if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)

View File

@ -91,7 +91,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb); const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
@ -101,9 +101,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// } // }
// We exit early and write 0 to gO and gLSE. }
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
// Otherwise we might read OOB elements from gK and gV. // Otherwise we might read OOB elements from gK and gV.
if (n_block_max <= n_block_min) { if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
// exit early and no one saves the rng state. // exit early and no one saves the rng state.
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
@ -145,7 +146,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
} }
return; return;
} }
}
// if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
// We iterate over the blocks in reverse order. This is because the last block is the only one // We iterate over the blocks in reverse order. This is because the last block is the only one