diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index 9dc8e6a..aa39992 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -115,6 +115,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
             /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1) / M)
         );
     }
+    // printf("smem_size = %d\n", smem_size);
    dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
        launch_params.params);
@@ -156,20 +157,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        }
    } else if (launch_params.params.d == 128) {
-        if( launch_params.params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            if (dprops->major == 8 && dprops->minor == 0) {
-                // TD [2022-06-05] Keep K in registers to reduce register spilling
-                // Gives about 6% speedup compared to using block size 128.
-                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            } else {  // Need to use the same block size as backward
-                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            }
-        }
+        // TD [2022-10-21]: Previously for SM80 we used block size 256 and kept K in shared memory
+        // to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
+        // reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
+        // some speedup (6-10%) for large batch size, but slows things down for small batch size.
+        // Now that we have better parallelism (over seqlen_q), block size 128 is faster for small
+        // batch size and only slightly slower (~3%) on large batch size.
+        // For causal=True, block size 128 seems always faster (for small & large batch size).
+        // So we're just gonna use block size 128 for simplicity.
+        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
+        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    }
    // if (launch_params.params.d == 64) {
    //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
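
The new comment attributes the switch to an occupancy drop: with ~105KB of shared memory only one block can reside per SM, versus two at ~41KB. A rough way to see how a kernel's dynamic shared-memory request limits residency per SM is to query the CUDA occupancy API, as in the standalone sketch below. dummy_kernel, the 128-thread block size, and the two byte counts are placeholders taken from the comment, not code from this repository; the real FMHA kernel's occupancy also depends on its register usage, so the printed counts will not necessarily match the 2-vs-1 figure quoted above.

// Standalone sketch (not part of this patch): ask the CUDA runtime how many blocks of a
// kernel can be co-resident on one SM for a given dynamic shared-memory request.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float *out) {
    extern __shared__ char smem[];   // stand-in for the FMHA kernel's dynamic shared memory
    smem[threadIdx.x] = 0;
    out[blockIdx.x * blockDim.x + threadIdx.x] = smem[threadIdx.x];
}

int main() {
    const int threads_per_block = 128;                  // 4 warps, matching the 128x128 kernel traits
    const int smem_sizes[] = {41 * 1024, 105 * 1024};   // approximate footprints from the comment above
    for (int smem_bytes : smem_sizes) {
        // Opt in to more than 48KB of dynamic shared memory (needed for the 105KB case on SM80).
        cudaFuncSetAttribute(dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
        int max_blocks = 0;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks, dummy_kernel,
                                                      threads_per_block, smem_bytes);
        printf("smem = %d KB -> at most %d block(s) per SM\n", smem_bytes / 1024, max_blocks);
    }
    return 0;
}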