From db2f80692cbe2b4e8fcf149b5b45e863f6b70c58 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 19 Nov 2023 22:20:01 -0800 Subject: [PATCH] Write zero to out / grad if seqlen_q or seqlen_k is zero --- csrc/flash_attn/flash_api.cpp | 19 +++++- csrc/flash_attn/src/flash_bwd_kernel.h | 5 +- csrc/flash_attn/src/flash_fwd_kernel.h | 88 +++++++++++++------------- 3 files changed, 63 insertions(+), 49 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index bf8cdcb..c40bbc1 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -405,8 +405,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size params.philox_args = gen->philox_cuda_state(counter_offset); } - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream); + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits<float>::infinity()); + } at::Tensor out_padded = out; if (head_size_og % 8 != 0) { @@ -794,7 +800,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si params.rng_state[1] = std::get<1>(seeds); } - launch(params, stream, /*configure=*/false); + if (seqlen_q > 0) { + launch(params, stream, /*configure=*/false); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+ dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 69dde7e..825e2e3 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -444,7 +444,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; const BlockInfo binfo(params, bidb); - if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return; + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); if (Is_local) { @@ -672,7 +672,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // We might need to exit early and write 0 to dK and dV for those blocks. // Otherwise we get wrong result for the case where we don't enter the for loop. // And we might read OOB elements from gQ and gdO.
- if (Is_local && m_block < m_block_min) { + // This also covers the case where actual_seqlen_q == 0 + if ((Is_local || !Is_even_MN) && m_block < m_block_min) { const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 323068e..42e1ef7 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -91,7 +91,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); @@ -101,50 +101,50 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } - // We exit early and write 0 to gO and gLSE. - // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= n_block_min) { - // Save seed and offset for backward. If we don't have this here, the 0-th thread block might - // exit early and no one saves the rng state.
- if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { - auto seeds = at::cuda::philox::unpack(params.philox_args); - params.rng_state[0] = std::get<0>(seeds); - params.rng_state[1] = std::get<1>(seeds); - } - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o), - Shape<Int<kBlockM>, Int<kHeadDim>>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse), - Shape<Int<kBlockM>>{}, Stride<_1>{}); - - typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_tensor<Element>(shape(tOgO)); - clear(tOrO); - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); - Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM - ); - #pragma unroll - for (int m = 0; m < size<1>(tOgO); ++m) { - const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } - } - return; + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
+ // Otherwise we might read OOB elements from gK and gV. + if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + auto seeds = at::cuda::philox::unpack(params.philox_args); + params.rng_state[0] = std::get<0>(seeds); + params.rng_state[1] = std::get<1>(seeds); } + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse), + Shape<Int<kBlockM>>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor<Element>(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row =
get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; } // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }