From a3a257c71de8600a9388cd8c35a96f959d1d2f64 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Fri, 16 Aug 2024 01:13:35 -0700 Subject: [PATCH] Fix out-of-bound writes for var-seq-len zero-length KVs --- hopper/epilogue_fwd_sm90_tma.hpp | 4 ++-- hopper/flash_fwd_kernel.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp index 5133c55..993f2e2 100644 --- a/hopper/epilogue_fwd_sm90_tma.hpp +++ b/hopper/epilogue_fwd_sm90_tma.hpp @@ -285,10 +285,10 @@ struct CollectiveEpilogueFwd { for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM ); static_assert(kBlockM <= NumMmaThreads); - if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; } + if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; } } }; diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h index 6b55021..f2041a4 100644 --- a/hopper/flash_fwd_kernel.h +++ b/hopper/flash_fwd_kernel.h @@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, } int n_block_max = collective_mainloop.get_n_block_max( mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if (Is_causal && n_block_max <= 0) { + if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { scheduler.prefetch_next_work(scheduler_params, work_tile_info); scheduler.broadcast_next_work(work_tile_info); continue; @@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, } int n_block_max = collective_mainloop.get_n_block_max( mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. + if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); continue; }