From 6a89b2f1216547620f25cf389699087a50b30355 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 3 Sep 2023 22:59:41 -0700
Subject: [PATCH] Remove constexpr in launch template to fix CI compilation

---
 csrc/cutlass                                    | 2 +-
 csrc/flash_attn/src/flash_bwd_launch_template.h | 8 ++++----
 csrc/flash_attn/src/flash_fwd_launch_template.h | 4 ++--
 flash_attn/__init__.py                          | 2 +-
 training/Dockerfile                             | 4 ++--
 5 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/csrc/cutlass b/csrc/cutlass
index 3a8f57a..34fd980 160000
--- a/csrc/cutlass
+++ b/csrc/cutlass
@@ -1 +1 @@
-Subproject commit 3a8f57a3c89cfff7aa686e95f13d9ad850f61898
+Subproject commit 34fd98056b69fbf7f0929b3f734bb5f00642e2c9
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index 81b2e4c..900fc26 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -64,7 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
                 // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
-                if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
+                if (smem_size_dq_dk_dv >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                 }
@@ -75,7 +75,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     });

     auto kernel_dq = &flash_bwd_convert_dq_kernel;
-    if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) {
+    if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
     }
@@ -103,7 +103,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel;
                 // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel;
-                if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
+                if (smem_size_dq_dk_dv >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                 }
@@ -114,7 +114,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     });

     auto kernel_dkv = &flash_bwd_convert_dkv_kernel;
-    if constexpr(Kernel_traits::kSmemKVSize >= 48 * 1024) {
+    if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
     }
diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index d036be4..2aa34a0 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -46,7 +46,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                 // Will only return softmax if dropout, to reduce compilation time.
                 auto kernel = &flash_fwd_kernel;
                 // auto kernel = &flash_fwd_kernel;
-                if constexpr(smem_size >= 48 * 1024) {
+                if (smem_size >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                 }
@@ -74,7 +74,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 auto kernel = &flash_fwd_splitkv_kernel;
                 // auto kernel = &flash_fwd_splitkv_kernel;
-                if constexpr(smem_size >= 48 * 1024) {
+                if (smem_size >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                 }
diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py
index 7a51e3d..5efd449 100644
--- a/flash_attn/__init__.py
+++ b/flash_attn/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.1.2.post1"
+__version__ = "2.1.2.post2"

 from flash_attn.flash_attn_interface import (
     flash_attn_func,
diff --git a/training/Dockerfile b/training/Dockerfile
index 0a8d366..133703e 100644
--- a/training/Dockerfile
+++ b/training/Dockerfile
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

 # Install FlashAttention
-RUN pip install flash-attn==2.1.2.post1
+RUN pip install flash-attn==2.1.2.post2

 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.1.2.post1 \
+    && cd flash-attention && git checkout v2.1.2.post2 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/xentropy && pip install . && cd ../../ \
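Note on the launch-template hunks: each one replaces `if constexpr(...)` with a plain `if (...)` around `cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, ...)`, which is the opt-in CUDA requires before a kernel may be launched with more than the default 48 KB of dynamic shared memory. The guard only needs to hold at launch time, so an ordinary runtime branch is sufficient and avoids the compilation failure the subject line refers to. Below is a minimal standalone sketch of the same guard-then-launch pattern; the kernel, sizes, and launch shape are hypothetical and only illustrate the CUDA runtime API usage, they are not FlashAttention code.

#include <cuda_runtime.h>

// Hypothetical kernel: stages one row of `row_len` floats through dynamic
// shared memory and writes it back scaled by 2.
__global__ void scale_rows_kernel(const float *in, float *out, int row_len) {
    extern __shared__ float row[];                     // dynamic shared memory
    const float *src = in + (size_t)blockIdx.x * row_len;
    float *dst = out + (size_t)blockIdx.x * row_len;
    for (int j = threadIdx.x; j < row_len; j += blockDim.x) { row[j] = src[j]; }
    __syncthreads();
    for (int j = threadIdx.x; j < row_len; j += blockDim.x) { dst[j] = 2.f * row[j]; }
}

// Hypothetical launcher: row_len is only known at run time, so the shared
// memory request is a runtime value.
void launch_scale_rows(const float *in, float *out, int num_rows, int row_len,
                       cudaStream_t stream) {
    const int smem_size = row_len * (int)sizeof(float);
    // Same guard shape as in the patch: requests of 48 KB or more need an
    // explicit opt-in via cudaFuncSetAttribute before launch, and a plain
    // runtime `if` is all that is required to decide this.
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(scale_rows_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size);
    }
    scale_rows_kernel<<<num_rows, 128, smem_size, stream>>>(in, out, row_len);
}

In real code the return value of cudaFuncSetAttribute should be checked, which is what the C10_CUDA_CHECK wrapper in the patch does.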