Don't nest BOOL_SWITCH to work around gcc 7 bug
parent d1fc80a3bb
commit bc2c210254
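Background for the diff below: BOOL_SWITCH takes a runtime boolean, a name, and a lambda, and invokes the lambda with that name bound to a constexpr bool of the matching value, so the flag can be used as a template argument inside the lambda. A minimal sketch of the idiom, assuming a definition along these lines (illustrative only; the actual macro in the flash-attention sources may differ in detail):

// Hypothetical sketch of the BOOL_SWITCH idiom, not the repo's exact definition.
// Usage: BOOL_SWITCH(flag, FlagConst, [&] { /* FlagConst is constexpr here */ });
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
    do {                                          \
        if (COND) {                               \
            constexpr bool CONST_NAME = true;     \
            __VA_ARGS__();                        \
        } else {                                  \
            constexpr bool CONST_NAME = false;    \
            __VA_ARGS__();                        \
        }                                         \
    } while (0)

Nesting two or three of these switches therefore nests the lambdas they expand around, which is what gcc 7 reportedly cannot handle (see the issues linked in the second hunk). The change keeps only the outermost switch and picks the remaining template arguments with ordinary ternaries over kernel function pointers.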
@@ -27,25 +27,27 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
     // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);

     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
-            auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                Kernel_traits, IsDropoutConst, IsCausalConst>;
-            if (params.seqlen_k == blocksize_c) {
-                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                    Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
-            } else if (params.seqlen_k == blocksize_c * 2) {
-                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                    Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
-            }
-            if( smem_size_dq_dk_dv >= 48 * 1024 ) {
-                FMHA_CHECK_CUDA(cudaFuncSetAttribute(
-                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-            }
-            dim3 grid(params.b, params.h);
-            kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
-            FMHA_CHECK_CUDA(cudaPeekAtLastError());
-        });
+        auto kernel = params.is_causal
+            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
+            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
+        if (params.seqlen_k == blocksize_c) {
+            kernel = params.is_causal
+                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
+                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
+        } else if (params.seqlen_k == blocksize_c * 2) {
+            kernel = params.is_causal
+                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
+                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
+        }
+        if( smem_size_dq_dk_dv >= 48 * 1024 ) {
+            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+        }
+        dim3 grid(params.b, params.h);
+        kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
+        FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }

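To see the flattening in isolation, here is a minimal, compilable sketch (hypothetical names, not part of the flash-attention sources), with the single remaining switch written out by hand instead of via the macro; the inner dispatch is a ternary over the two instantiations that differ in the causal flag, mirroring the hunk above:

#include <cstdio>

// Stand-in for a real kernel launch; the two bools mimic template flags such as
// IsDropoutConst / IsCausalConst in the launcher above.
template <bool IsDropout, bool IsCausal>
void launch_dummy(int seqlen) {
    std::printf("dropout=%d causal=%d seqlen=%d\n", IsDropout, IsCausal, seqlen);
}

void run(bool is_dropout, bool is_causal, int seqlen) {
    // Outer level: the one remaining boolean switch, expanded by hand.
    if (is_dropout) {
        constexpr bool IsDropoutConst = true;
        // Inner level: no second switch, just a ternary over the two
        // instantiations that differ in the causal flag.
        auto kernel = is_causal ? &launch_dummy<IsDropoutConst, true>
                                : &launch_dummy<IsDropoutConst, false>;
        kernel(seqlen);
    } else {
        constexpr bool IsDropoutConst = false;
        auto kernel = is_causal ? &launch_dummy<IsDropoutConst, true>
                                : &launch_dummy<IsDropoutConst, false>;
        kernel(seqlen);
    }
}

int main() {
    run(/*is_dropout=*/true, /*is_causal=*/false, 128);
    run(/*is_dropout=*/false, /*is_causal=*/true, 256);
    return 0;
}

Both branches of each ternary are function pointers of the same type, so every instantiation is still generated at compile time; only the selection between them happens at run time.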
@@ -59,21 +59,25 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
         + (loop_steps > 1 ? smem_size_softmax_lse : 0);

+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+    // https://github.com/kokkos/kokkos-kernels/issues/349
+    // https://github.com/HazyResearch/flash-attention/issues/21
     BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(launch_params.params.is_causal, IsCausalConst, [&] {
-            BOOL_SWITCH(launch_params.return_softmax, ReturnSoftmaxConst, [&] {
-                auto kernel = &fmha_fprop_fp16_sm80_loop_kernel<
-                    Kernel_traits, IsDropoutConst, IsCausalConst, ReturnSoftmaxConst>;
-                if( smem_size >= 48 * 1024 ) {
-                    FMHA_CHECK_CUDA(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                }
-                dim3 grid(launch_params.params.b, launch_params.params.h);
-                kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
-                    launch_params.params);
-                FMHA_CHECK_CUDA(cudaPeekAtLastError());
-            });
-        });
+        auto kernel = launch_params.params.is_causal
+            ? (launch_params.return_softmax
+               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
+               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
+            : (launch_params.return_softmax
+               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
+               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
+        if( smem_size >= 48 * 1024 ) {
+            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        }
+        dim3 grid(launch_params.params.b, launch_params.params.h);
+        kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+            launch_params.params);
+        FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }

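Net effect: the same kernel template instantiations are compiled as before (for the forward launcher, all eight combinations of the dropout, causal, and return-softmax flags; for the backward launcher, both causal variants of each loop_steps case), and the launch configuration is untouched, so no behavioral difference is expected. The trade-off is slightly more verbose launcher code in exchange for avoiding the nested BOOL_SWITCH lambdas that gcc 7 has trouble compiling, per the issues linked in the second hunk.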