diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 67be427..8841d6b 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -133,7 +133,7 @@ void run_fmha_fp16_sm80(Launch_params &launch_params, using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; run_fmha_fp16_sm80_loop_(launch_params, configure); } else { - if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { + if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) { // TD [2022-06-05] Keep K in registers to reduce register spilling // Gives about 6% speedup compared to using block size 128. using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;