// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "fmha_bwd_launch_template.h"
// Backward-pass launcher for head dimension 32.
//
// Dispatches on the element type (fp16 vs bf16) via FP16_SWITCH — which
// presumably provides `elem_type` inside the lambda; confirm against the
// macro definition — then selects kernel traits by key-sequence-length
// bucket and hands off to run_fmha_bwd_loop.
//
// Trait template arguments: <seqlen_k bucket, head dim, warps?, ..., elem_type>
// (only the first two are grounded here: 128/256 bucket and head dim 32 —
// TODO confirm the remaining parameters' meaning against FMHA_kernel_traits).
void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    FP16_SWITCH(params.is_bf16, ([&] {
        if (params.seqlen_k == 128) {
            // Short sequences: 128-wide kernel traits.
            using Traits128 = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Traits128>(params, stream, configure);
        } else if (params.seqlen_k >= 256) {
            // Longer sequences: 256-wide kernel traits (looped over seqlen).
            using Traits256 = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Traits256>(params, stream, configure);
        }
        // NOTE(review): seqlen_k in (128, 256) launches nothing — presumably
        // callers pad seqlen_k to a multiple of 128; verify at the call site.
    }));
}