Use block_size=128 for headdim=128 on SM80
Previously we were using block_size=256.
This commit is contained in:
parent
a44f48df5a
commit
7fc39832e2
@ -115,6 +115,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
|
|||||||
/*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
|
/*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
// printf("smem_size = %d\n", smem_size);
|
||||||
dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
|
dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
|
||||||
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
|
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
|
||||||
launch_params.params);
|
launch_params.params);
|
||||||
@ -156,20 +157,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
|
|||||||
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
||||||
}
|
}
|
||||||
} else if (launch_params.params.d == 128) {
|
} else if (launch_params.params.d == 128) {
|
||||||
if( launch_params.params.seqlen_k == 128 ) {
|
// TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
|
||||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
|
// to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
|
||||||
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
// reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
|
||||||
} else {
|
// some speedup (6-10%) for large batch size, but slows things down for small batch size.
|
||||||
if (dprops->major == 8 && dprops->minor == 0) {
|
// Now that we have better parallelism (over seqlen_q), block size 128 is faster for small
|
||||||
// TD [2022-06-05] Keep K in registers to reduce register spilling
|
// batch size and only slightly slower (~3%) on large batch size.
|
||||||
// Gives about 6% speedup compared to using block size 128.
|
// For causal=True, block size 128 seems always faster (for small & large batch size).
|
||||||
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
|
// So we're just gonna use block size 128 for simplicity.
|
||||||
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
|
||||||
} else { // Need to use the same block size as backward
|
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
||||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
|
|
||||||
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// if (launch_params.params.d == 64) {
|
// if (launch_params.params.d == 64) {
|
||||||
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
|
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user