From bc2c210254dbbe8814498449a562bc1996c5bf7d Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 11 Jul 2022 10:28:46 -0700
Subject: [PATCH] Don't nest BOOL_SWITCH to work around gcc 7 bug

---
 .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu   | 38 ++++++++++---------
 .../src/fmha_fprop_fp16_kernel.sm80.cu        | 32 +++++++++-------
 2 files changed, 38 insertions(+), 32 deletions(-)

diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
index 25c7116..403ce92 100644
--- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -27,25 +27,27 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
     // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
 
     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
-            auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                Kernel_traits, IsDropoutConst, IsCausalConst>;
-            if (params.seqlen_k == blocksize_c) {
-                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                    Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
-            } else if (params.seqlen_k == blocksize_c * 2) {
-                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                    Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
-            }
-            if( smem_size_dq_dk_dv >= 48 * 1024 ) {
-                FMHA_CHECK_CUDA(cudaFuncSetAttribute(
-                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-            }
-            dim3 grid(params.b, params.h);
-            kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
-            FMHA_CHECK_CUDA(cudaPeekAtLastError());
-        });
+        auto kernel = params.is_causal
+            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
+            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
+        if (params.seqlen_k == blocksize_c) {
+            kernel = params.is_causal
+                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
+                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
+        } else if (params.seqlen_k == blocksize_c * 2) {
+            kernel = params.is_causal
+                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
+                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
+        }
+        if( smem_size_dq_dk_dv >= 48 * 1024 ) {
+            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+        }
+        dim3 grid(params.b, params.h);
+        kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
+        FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }
 
diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index cb05eee..67be427 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -59,21 +59,25 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
         + (loop_steps > 1 ? smem_size_softmax_lse : 0);
 
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+    // https://github.com/kokkos/kokkos-kernels/issues/349
+    // https://github.com/HazyResearch/flash-attention/issues/21
     BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(launch_params.params.is_causal, IsCausalConst, [&] {
-            BOOL_SWITCH(launch_params.return_softmax, ReturnSoftmaxConst, [&] {
-                auto kernel = &fmha_fprop_fp16_sm80_loop_kernel<
-                    Kernel_traits, IsDropoutConst, IsCausalConst, ReturnSoftmaxConst>;
-                if( smem_size >= 48 * 1024 ) {
-                    FMHA_CHECK_CUDA(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                }
-                dim3 grid(launch_params.params.b, launch_params.params.h);
-                kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
-                    launch_params.params);
-                FMHA_CHECK_CUDA(cudaPeekAtLastError());
-            });
-        });
+        auto kernel = launch_params.params.is_causal
+            ? (launch_params.return_softmax
+               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
+               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
+            : (launch_params.return_softmax
+               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
+               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
+        if( smem_size >= 48 * 1024 ) {
+            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        }
+        dim3 grid(launch_params.params.b, launch_params.params.h);
+        kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+            launch_params.params);
+        FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }
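
Note: BOOL_SWITCH here is the usual "runtime bool to constexpr" dispatch pattern; the sketch below is an assumed shape for illustration, not the exact definition from csrc/flash_attn. Each switch wraps its body in an immediately invoked lambda, so nesting two or three of them produces lambdas inside lambdas, which is the construct gcc 7 mishandles (see the kokkos-kernels issue linked above). The patch therefore keeps a single BOOL_SWITCH for the dropout flag and selects the remaining template booleans with ordinary ternaries over explicitly instantiated kernel pointers.

    // Sketch of a BOOL_SWITCH-style macro (assumed shape, for illustration only).
    // It turns a runtime bool into a compile-time constant by compiling the body
    // once with `true` and once with `false`, then picking a branch at runtime.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)       \
        [&] {                                        \
            if (COND) {                              \
                constexpr bool CONST_NAME = true;    \
                return __VA_ARGS__();                \
            } else {                                 \
                constexpr bool CONST_NAME = false;   \
                return __VA_ARGS__();                \
            }                                        \
        }()

    // Nested use (what gcc 7 chokes on): a lambda inside a lambda, each capturing [&].
    //   BOOL_SWITCH(a, A, [&] { BOOL_SWITCH(b, B, [&] { run<A, B>(); }); });
    // Flattened use (what this patch switches to): one BOOL_SWITCH plus a runtime
    // ternary over explicitly instantiated function pointers.
    //   BOOL_SWITCH(a, A, [&] { auto f = b ? &run<A, true> : &run<A, false>; f(); });

The flattened form trades some verbosity (every combination of the remaining bools is spelled out) for compatibility with the older compiler; the set of kernel instantiations is the same either way.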