diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index fa45398..d13f1d5 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -63,7 +63,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
+                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
                 // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
                 if (smem_size_dq_dk_dv >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index 51d7576..4c336ec 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -45,7 +45,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                 // Will only return softmax if dropout, to reduce compilation time.
                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                auto kernel = &flash_fwd_kernel;
+                // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                auto kernel = &flash_fwd_kernel;
                 // auto kernel = &flash_fwd_kernel;
                 if (smem_size >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -78,7 +80,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                 BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                     // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                    auto kernel = &flash_fwd_splitkv_kernel;
+                    auto kernel = &flash_fwd_splitkv_kernel;
                     // auto kernel = &flash_fwd_splitkv_kernel;
                     // auto kernel = &flash_fwd_splitkv_kernel;
                     if (smem_size >= 48 * 1024) {
diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index 71259d0..bd05258 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -371,6 +371,7 @@ class RotaryEmbedding(torch.nn.Module):
         # or if we're switching from inference mode to training
         if (
             seqlen > self._seq_len_cached
+            or self._cos_cached is None
             or self._cos_cached.device != device
             or self._cos_cached.dtype != dtype
             or (self.training and self._cos_cached.is_inference())
diff --git a/setup.py b/setup.py
index 33f3813..f5e17a4 100644
--- a/setup.py
+++ b/setup.py
@@ -201,7 +201,7 @@ if not SKIP_CUDA_BUILD:
                         "--use_fast_math",
                         # "--ptxas-options=-v",
                         # "--ptxas-options=-O2",
-                        "-lineinfo",
+                        # "-lineinfo",
                     ]
                     + generator_flag
                     + cc_flag
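
The comments added in the launch templates all describe the same compile-time trade-off: each boolean template parameter doubles the number of kernel instantiations, so a flag is forced to false whenever its specialized path cannot be used anyway (uneven K, head dim > 128, returning softmax). Below is a minimal standalone sketch of that folding pattern, under the assumption of a toy dispatcher; `toy_kernel` and `dispatch` are hypothetical names, not part of the FlashAttention sources.

```cpp
// Sketch of the "fold flags to reduce template instantiations" idea the
// diff comments describe. toy_kernel and dispatch are illustrative only.
#include <cstdio>

template <bool IsEvenMN, bool IsEvenK>
void toy_kernel() {
    std::printf("instantiation: IsEvenMN=%d, IsEvenK=%d\n", IsEvenMN, IsEvenK);
}

// The flags and head_dim are runtime values; only combinations that can
// actually take the specialized path map to a <true, ...> instantiation.
void dispatch(bool is_even_mn, bool is_even_k, bool return_softmax, int head_dim) {
    if (is_even_k) {
        if (is_even_mn && !return_softmax && head_dim <= 128) {
            toy_kernel<true, true>();
        } else {
            toy_kernel<false, true>();
        }
    } else {
        // If K is uneven, also treat MN as uneven: one fewer template.
        toy_kernel<false, false>();
    }
}

int main() {
    dispatch(true, true, false, 64);   // specialized path: <true, true>
    dispatch(true, true, true, 64);    // folded to <false, true>
    dispatch(true, false, false, 64);  // folded to <false, false>
    return 0;
}
```

With the folding, only three of the four `<IsEvenMN, IsEvenK>` combinations are ever instantiated, which is the effect the added comments aim for at a larger scale. The rotary.py hunk is the defensive counterpart on the Python side: `self._cos_cached` can be `None` before the cache is first built, so it is checked before its `.device` and `.dtype` are read.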