Use block_size=128 for headdim=128 on SM80

Previously we were using block_size=256.
Tri Dao committed on 2022-10-21 13:19:54 -07:00
parent a44f48df5a
commit 7fc39832e2


@@ -115,6 +115,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
             /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1) / M)
         );
     }
+    // printf("smem_size = %d\n", smem_size);
     dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);
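
For reference, the max_splits expression above is a ceiling division capped at 30: seqlen_q is split into ceil(seqlen_q / M) chunks, and the resulting num_splits becomes the grid's third dimension. A minimal standalone sketch of that arithmetic (the function name is hypothetical, and it assumes M is the number of query rows a CTA covers per split):

    #include <algorithm>
    #include <cstdio>

    // Hypothetical standalone version of the cap-at-30 split heuristic above.
    int max_splits_for(int seqlen_q, int M) {
        int chunks = (seqlen_q + M - 1) / M;  // ceil(seqlen_q / M)
        return std::min(30, chunks);          // keep the grid's z-dimension small
    }

    int main() {
        // e.g. seqlen_q = 2048 with M = 128 gives min(30, 16) = 16 splits.
        std::printf("max_splits = %d\n", max_splits_for(2048, 128));
        return 0;
    }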
@@ -156,20 +157,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         }
     } else if (launch_params.params.d == 128) {
-        if( launch_params.params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            if (dprops->major == 8 && dprops->minor == 0) {
-                // TD [2022-06-05] Keep K in registers to reduce register spilling
-                // Gives about 6% speedup compared to using block size 128.
-                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            } else { // Need to use the same block size as backward
-                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            }
-        }
+        // TD [2022-10-21]: Previously on SM80 we used block size 256 and kept K in shared memory
+        // to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
+        // reducing occupancy (only 1 CTA can be scheduled per SM instead of 2). This strategy gives
+        // some speedup (6-10%) for large batch sizes but slows things down for small batch sizes.
+        // Now that we have better parallelism (over seqlen_q), block size 128 is faster for small
+        // batch sizes and only slightly slower (~3%) for large batch sizes.
+        // For causal=True, block size 128 seems always faster (for both small and large batch sizes).
+        // So we just use block size 128 for simplicity.
+        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
+        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     }
     // if (launch_params.params.d == 64) {
     //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
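
The occupancy claim in the new comment can be sanity-checked with the CUDA runtime's occupancy API. A hedged sketch, not part of this commit: dummy_kernel and the 128-thread block size stand in for the real FMHA kernel, and the byte counts mirror the ~41KB and ~105KB figures from the comment.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Stand-in for the real FMHA kernel (hypothetical).
    __global__ void dummy_kernel() {}

    int main() {
        for (int smem : {41 * 1024, 105 * 1024}) {
            // Dynamic shared memory above 48KB requires an explicit opt-in.
            cudaFuncSetAttribute(dummy_kernel,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
            int ctas_per_sm = 0;
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &ctas_per_sm, dummy_kernel, /*blockSize=*/128, smem);
            // On SM80 (~164KB usable smem per SM), 105KB leaves room for only
            // 1 CTA per SM, while 41KB leaves room for 2 or more; the real
            // kernel lands at 2 once register pressure is accounted for.
            std::printf("smem = %dKB -> %d CTA(s) per SM\n", smem / 1024, ctas_per_sm);
        }
        return 0;
    }

Since num_splits already sits in the grid's z-dimension (see the dim3 grid(...) launch above), lowering per-CTA smem translates directly into more resident CTAs per SM and better latency hiding, which is the tradeoff the comment describes.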