From 43617deab916bc0fea9f823ef99f31772352f994 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 18 Sep 2023 11:46:15 -0700 Subject: [PATCH] Remove template for (IsEvenMN=T, IsEvenK=F) to speed up compilation --- csrc/flash_attn/src/flash_bwd_launch_template.h | 6 ++++-- csrc/flash_attn/src/flash_fwd_launch_template.h | 6 ++++-- training/Dockerfile | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 900fc26..f4f2388 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -62,7 +62,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( @@ -101,7 +102,8 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+ auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index ef713a8..9c8c750 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -44,7 +44,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; if (smem_size >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( @@ -76,7 +77,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) { BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - auto kernel = &flash_fwd_splitkv_kernel; + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+ auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { diff --git a/training/Dockerfile b/training/Dockerfile index bc6f446..fbea6cb 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -87,7 +87,7 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention RUN pip install flash-attn==2.2.3.post2 -# Install CUDA extensions for cross-entropy, fused dense, layer norm +# Install CUDA extensions for fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ && cd flash-attention && git checkout v2.2.3.post2 \ && cd csrc/layer_norm && pip install . && cd ../../ \