Fix out-of-bound writes for var-seq-len zero-length KVs

This commit is contained in:
Ying Zhang 2024-08-16 01:13:35 -07:00
parent bcd918f275
commit a3a257c71d
2 changed files with 4 additions and 4 deletions

View File

@ -285,10 +285,10 @@ struct CollectiveEpilogueFwd {
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM
);
static_assert(kBlockM <= NumMmaThreads);
if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
}
};

View File

@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
}
int n_block_max = collective_mainloop.get_n_block_max(
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
if (Is_causal && n_block_max <= 0) {
if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
continue;
@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
}
int n_block_max = collective_mainloop.get_n_block_max(
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
continue;
}