flash-attention/csrc/flash_attn/src/fmha_bwd_hdim32.cu
Tri Dao a1f49a2b92 [Compilation] Change BOOL_SWITCH to fix Windows compilation
Follow xFormers' DISPATCH_BOOL. Haven't tested it on Windows.
2023-01-06 14:40:58 -08:00

17 lines
727 B
Plaintext

// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "fmha_bwd_launch_template.h"
// Host-side dispatcher for the backward pass variant compiled in this file.
//
// Picks an FMHA_kernel_traits instantiation from params.seqlen_k and hands
// params, stream and configure on to run_fmha_bwd_loop unchanged.
//
//   params    - read here for is_bf16 (fed to FP16_SWITCH) and seqlen_k
//               (selects the traits instantiation).
//   stream    - passed through to run_fmha_bwd_loop.
//   configure - passed through to run_fmha_bwd_loop.
//
// NOTE(review): elem_type is not declared in this function; it is presumably
// introduced by FP16_SWITCH according to params.is_bf16 - confirm against the
// macro definition.
// NOTE(review): the meaning of the individual FMHA_kernel_traits template
// arguments is not visible here; only the first one varies (128 vs 256) and it
// tracks the seqlen_k test. The literal 32 presumably matches the "hdim32" in
// the function name - confirm against the traits header.
void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    FP16_SWITCH(params.is_bf16, ([&] {
        // The two conditions are mutually exclusive, so their order is irrelevant.
        if (params.seqlen_k >= 256) {
            using Traits_seq256 = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Traits_seq256>(params, stream, configure);
        } else if (params.seqlen_k == 128) {
            using Traits_seq128 = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Traits_seq128>(params, stream, configure);
        }
        // Any other seqlen_k value falls through without calling run_fmha_bwd_loop.
    }));
}