Fix performance regression with causal
This commit is contained in:
parent
2c3baba4a6
commit
b4bf9cc1f3
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
||||
from src.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
||||
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
# # from flash_attn.triton.fused_attention import attention as attention
|
||||
# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
|
||||
|
||||
@ -60,15 +60,15 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
@ -104,11 +104,11 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
|
||||
@ -43,14 +43,16 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
@ -79,13 +81,13 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
|
||||
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
|
||||
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
|
||||
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user