Use block_size=128 for headdim=128 on SM80

Previously we were using block_size=256.
This commit is contained in:
Tri Dao 2022-10-21 13:19:54 -07:00
parent a44f48df5a
commit 7fc39832e2

View File

@ -115,6 +115,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
/*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M)) /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
); );
} }
// printf("smem_size = %d\n", smem_size);
dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits); dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>( kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params); launch_params.params);
@ -156,20 +157,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} }
} else if (launch_params.params.d == 128) { } else if (launch_params.params.d == 128) {
if( launch_params.params.seqlen_k == 128 ) { // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; // to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); // reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
} else { // some speedup (6-10%) for large batch size, but slows things down for small batch size.
if (dprops->major == 8 && dprops->minor == 0) { // Now that we have better parallelism (over seqlen_q), block size 128 is faster for small
// TD [2022-06-05] Keep K in registers to reduce register spilling // batch size and only slightly slower (~3%) on large batch size.
// Gives about 6% speedup compared to using block size 128. // For causal=True, block size 128 seems always faster (for small & large batch size).
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; // So we're just gonna use block size 128 for simplicity.
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
} else { // Need to use the same block size as backward run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
}
} }
// if (launch_params.params.d == 64) { // if (launch_params.params.d == 64) {
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; // // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;