// flash-attention/csrc/flash_attn/src/fmha_bwd_hdim32.cu
// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "fmha_bwd_launch_template.h"
// Launches the FMHA backward pass specialized for head dimension 32.
// Dispatches on params.is_bf16 via FP16_SWITCH, which supplies `elem_type`
// (fp16 vs bf16) to the lambda, then selects Kernel_traits by params.seqlen_k
// and forwards to run_fmha_bwd_loop<Kernel_traits>.
// NOTE(review): `configure` is forwarded untouched — presumably a dry-run /
// setup mode handled inside run_fmha_bwd_loop; confirm in the launch template.
void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
FP16_SWITCH(params.is_bf16, ([&] {
// Traits template arguments: first is the seqlen_k tile, second is the head
// dim (32 here). The remaining parameters (16, 1, 8, 0x08u) are tuning
// constants defined by FMHA_kernel_traits — see the traits header for their
// exact meaning; they are identical across both branches.
if (params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
} else if (params.seqlen_k >= 256) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
}
// NOTE(review): 128 < seqlen_k < 256 launches nothing (silent no-op).
// Presumably callers pad seqlen_k to a multiple of 128 so this range never
// occurs — verify against the host-side setup code.
}));
}