address comments

This commit is contained in:
Ying Zhang 2024-09-19 22:00:41 -07:00
parent be6c1b98c4
commit 1c9717d699
4 changed files with 10 additions and 8 deletions

View File

@@ -253,7 +253,7 @@ public:
}
if constexpr (Is_causal) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
-int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
+int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
if (m_block_min >= m_block_max) { continue; }
}
collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);

View File

@@ -23,6 +23,7 @@ using namespace cute;
template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_causal, bool Is_local, bool Varlen, bool Deterministic,
bool dKV_swapAB, bool dQ_swapAB, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
using ElementAccum = float;
using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90, /*Clear_dQaccum=*/true, Varlen>;
@@ -174,7 +175,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_local, Is_local, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
-run_flash_bwd<Headdim, 128, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
+run_flash_bwd<Headdim, 128, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
});
});
});
@@ -188,7 +189,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_local, Is_local, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
-run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
+run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
});
});
});
@@ -202,7 +203,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_local, Is_local, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
-run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
+run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
});
});
});

View File

@@ -20,6 +20,7 @@
template<typename Kernel_traits, bool Is_causal, bool Is_local, typename Seqlen_traits>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
using Element = typename Kernel_traits::Element;
using OutputType = typename Kernel_traits::OutputType;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
@@ -121,7 +122,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>,
-Is_causal, Is_local, Seqlen_traits
+Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
});
});
@@ -138,7 +139,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, (Is_causal || Is_local) ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
-Is_causal, Is_local, Seqlen_traits
+Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
});
});
@@ -156,7 +157,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>,
-Is_causal, Is_local, Seqlen_traits
+Is_causal, Is_local && !Is_causal, Seqlen_traits
>(params, stream);
});
});

View File

@@ -613,7 +613,7 @@ struct CollectiveMainloopBwd {
}
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
}
-if constexpr (Deterministic) {
+if constexpr (Is_local && Deterministic) {
constexpr int kBlockM = get<0>(TileShape_MNK{});
int const seqlen_q = get_seqlen_q(params, bidb);
int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM);