Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit

This commit is contained in:
Tri Dao 2022-09-12 14:21:29 -07:00
parent 13403e8115
commit 8166063a55

View File

@ -133,7 +133,7 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) {
// TD [2022-06-05] Keep K in registers to reduce register spilling
// Gives about 6% speedup compared to using block size 128.
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;