address comments
This commit is contained in:
parent
be6c1b98c4
commit
1c9717d699
@ -253,7 +253,7 @@ public:
|
||||
}
|
||||
if constexpr (Is_causal) {
|
||||
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
|
||||
int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
|
||||
int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
|
||||
if (m_block_min >= m_block_max) { continue; }
|
||||
}
|
||||
collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
|
||||
|
||||
@ -23,6 +23,7 @@ using namespace cute;
|
||||
template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_causal, bool Is_local, bool Varlen, bool Deterministic,
|
||||
bool dKV_swapAB, bool dQ_swapAB, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
|
||||
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
|
||||
using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
|
||||
using ElementAccum = float;
|
||||
using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90, /*Clear_dQaccum=*/true, Varlen>;
|
||||
@ -174,7 +175,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.is_local, Is_local, [&] {
|
||||
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
|
||||
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
|
||||
run_flash_bwd<Headdim, 128, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
|
||||
run_flash_bwd<Headdim, 128, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -188,7 +189,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.is_local, Is_local, [&] {
|
||||
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
|
||||
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
|
||||
run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
|
||||
run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -202,7 +203,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.is_local, Is_local, [&] {
|
||||
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
|
||||
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
|
||||
run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
|
||||
run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local && !Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_local, typename Seqlen_traits>
|
||||
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using OutputType = typename Kernel_traits::OutputType;
|
||||
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
|
||||
@ -121,7 +122,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
||||
run_flash_fwd<
|
||||
Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>,
|
||||
Is_causal, Is_local, Seqlen_traits
|
||||
Is_causal, Is_local && !Is_causal, Seqlen_traits
|
||||
>(params, stream);
|
||||
});
|
||||
});
|
||||
@ -138,7 +139,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
|
||||
run_flash_fwd<
|
||||
Flash_fwd_kernel_traits<Headdim, 128, (Is_causal || Is_local) ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
|
||||
Is_causal, Is_local, Seqlen_traits
|
||||
Is_causal, Is_local && !Is_causal, Seqlen_traits
|
||||
>(params, stream);
|
||||
});
|
||||
});
|
||||
@ -156,7 +157,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
|
||||
run_flash_fwd<
|
||||
Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>,
|
||||
Is_causal, Is_local, Seqlen_traits
|
||||
Is_causal, Is_local && !Is_causal, Seqlen_traits
|
||||
>(params, stream);
|
||||
});
|
||||
});
|
||||
|
||||
@ -613,7 +613,7 @@ struct CollectiveMainloopBwd {
|
||||
}
|
||||
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
|
||||
}
|
||||
if constexpr (Deterministic) {
|
||||
if constexpr (Is_local && Deterministic) {
|
||||
constexpr int kBlockM = get<0>(TileShape_MNK{});
|
||||
int const seqlen_q = get_seqlen_q(params, bidb);
|
||||
int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user