From 8166063a556e17e03e4a0697ba604def1eeb6a99 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 12 Sep 2022 14:21:29 -0700
Subject: [PATCH] Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit

---
 csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index 67be427..8841d6b 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -133,7 +133,7 @@ void run_fmha_fp16_sm80(Launch_params &launch_params,
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fp16_sm80_loop_(launch_params, configure);
         } else {
-            if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
+            if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) {
                 // TD [2022-06-05] Keep K in registers to reduce register spilling
                 // Gives about 6% speedup compared to using block size 128.
                 using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;