From 1c9717d699c720ce62b662b376ce224988609fbd Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Thu, 19 Sep 2024 22:00:41 -0700 Subject: [PATCH] address comments --- hopper/flash_bwd_kernel.h | 2 +- hopper/flash_bwd_launch_template.h | 7 ++++--- hopper/flash_fwd_launch_template.h | 7 ++++--- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 2 +- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/hopper/flash_bwd_kernel.h b/hopper/flash_bwd_kernel.h index ccd53f3..9eba8fb 100644 --- a/hopper/flash_bwd_kernel.h +++ b/hopper/flash_bwd_kernel.h @@ -253,7 +253,7 @@ public: } if constexpr (Is_causal) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); - int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); + int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); if (m_block_min >= m_block_max) { continue; } } collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord); diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index 2fe1655..2683e72 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -23,6 +23,7 @@ using namespace cute; template void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { + static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); using TileShape_MK = cute::Shape, Int>; using ElementAccum = float; using PreprocessKernel = flash::FlashAttnBwdPreprocess; @@ -174,7 +175,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) { BOOL_SWITCH(params.is_local, Is_local, [&] { BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + run_flash_bwd(params, stream); }); }); }); @@ -188,7 +189,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, 
cudaStream_t stream) { BOOL_SWITCH(params.is_local, Is_local, [&] { BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + run_flash_bwd(params, stream); }); }); }); @@ -202,7 +203,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) { BOOL_SWITCH(params.is_local, Is_local, [&] { BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + run_flash_bwd(params, stream); }); }); }); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 2ed0521..0c0790e 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -20,6 +20,7 @@ template void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { + static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); using Element = typename Kernel_traits::Element; using OutputType = typename Kernel_traits::OutputType; using TileShape_MNK = typename Kernel_traits::TileShape_MNK; @@ -121,7 +122,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) { SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { run_flash_fwd< Flash_fwd_kernel_traits, - Is_causal, Is_local, Seqlen_traits + Is_causal, Is_local && !Is_causal, Seqlen_traits >(params, stream); }); }); @@ -138,7 +139,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) { BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { run_flash_fwd< Flash_fwd_kernel_traits, - Is_causal, Is_local, Seqlen_traits + Is_causal, Is_local && !Is_causal, Seqlen_traits >(params, stream); }); }); @@ -156,7 +157,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) { 
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { run_flash_fwd< Flash_fwd_kernel_traits, - Is_causal, Is_local, Seqlen_traits + Is_causal, Is_local && !Is_causal, Seqlen_traits >(params, stream); }); }); diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 334cdca..d999536 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -613,7 +613,7 @@ struct CollectiveMainloopBwd { } cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to } - if constexpr (Deterministic) { + if constexpr (Is_local && Deterministic) { constexpr int kBlockM = get<0>(TileShape_MNK{}); int const seqlen_q = get_seqlen_q(params, bidb); int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM);