Simplify BOOL_SWITCH macro to fix compilation error on gcc 7
This commit is contained in:
parent
a84d07283c
commit
8a2ece89f7
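
Background for this change: the old BOOL_SWITCH (and FP16_SWITCH) wrapped the dispatched code in an immediately-invoked lambda, `[&] { ... }()`, and gcc 7 fails to compile nested uses of that pattern (see the kokkos-kernels and flash-attention issues linked in the diff below). The new macros are plain if/else statements that paste the code block into each branch; call sites wrap the block in `({ ... })` so that commas inside it, e.g. in template argument lists, do not split it into several macro arguments. The sketch below contrasts the two shapes. It is illustrative only: `BOOL_SWITCH_LAMBDA`, `BOOL_SWITCH_CODE`, and `toy_kernel` are made-up stand-ins, not code from this repository, and it compiles as host C++ with gcc or clang because the `({ ... });` expansion relies on the GNU statement-expression extension.

```
#include <cstdio>

// Old shape: the body is a lambda invoked once per branch. Nesting two of
// these is the pattern that reportedly breaks gcc 7.
#define BOOL_SWITCH_LAMBDA(COND, CONST_NAME, ...) \
    [&] {                                         \
        if (COND) {                               \
            constexpr bool CONST_NAME = true;     \
            return __VA_ARGS__();                 \
        } else {                                  \
            constexpr bool CONST_NAME = false;    \
            return __VA_ARGS__();                 \
        }                                         \
    }()

// New shape: the code block is pasted verbatim into each branch.
#define BOOL_SWITCH_CODE(COND, CONST_NAME, CODE)  \
    if (COND) {                                   \
        constexpr bool CONST_NAME = true;         \
        CODE;                                     \
    } else {                                      \
        constexpr bool CONST_NAME = false;        \
        CODE;                                     \
    }

template <bool A, bool B>
void toy_kernel() { std::printf("A=%d B=%d\n", A, B); }

int main() {
    bool outer = true, inner = false;

    // Old call-site shape: nested immediately-invoked lambdas.
    BOOL_SWITCH_LAMBDA(outer, OuterC, [&] {
        BOOL_SWITCH_LAMBDA(inner, InnerC, [&] {
            toy_kernel<OuterC, InnerC>();
        });
    });

    // New call-site shape: ({ ... }) keeps the block, commas and all, a
    // single macro argument; it expands to a GNU statement expression.
    BOOL_SWITCH_CODE(outer, OuterC2, ({
        BOOL_SWITCH_CODE(inner, InnerC2, ({
            toy_kernel<OuterC2, InnerC2>();
        }));
    }));
    return 0;
}
```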
@@ -36,7 +36,8 @@
 #include <ATen/cuda/CUDAGeneratorImpl.h>
 #endif
 
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/UnpackRaw.cuh>
 
 #include <fmha_utils.h>
@@ -5,9 +5,8 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
+    FP16_SWITCH(params.is_bf16, ({
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
         run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
-    });
+    }));
 }
@@ -5,8 +5,7 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
+    FP16_SWITCH(params.is_bf16, ({
         if (params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
             run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
@@ -14,5 +13,5 @@ void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const b
             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
             run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
         }
-    });
+    }));
 }
@@ -5,8 +5,7 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
+    FP16_SWITCH(params.is_bf16, ({
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
@@ -27,5 +26,5 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const b
                 run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
             }
         }
-    });
+    }));
 }
@@ -3,7 +3,6 @@
 #pragma once
 
 #include "static_switch.h"
-#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"
 
@@ -62,7 +61,7 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo
 
     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
     // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
-    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+    BOOL_SWITCH(is_dropout, IsDropoutConst, ({
         auto kernel = params.is_causal
             ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
             : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
@@ -111,5 +110,5 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo
            kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
         }
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
-    });
+    }));
 }
@@ -5,8 +5,8 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+    FP16_SWITCH(launch_params.params.is_bf16, ({
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
         run_fmha_fwd_loop<Kernel_traits>(launch_params);
-    });
+    }));
 }
@@ -5,7 +5,7 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+    FP16_SWITCH(launch_params.params.is_bf16, ({
         if (launch_params.params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);
@@ -13,5 +13,5 @@ void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);
         }
-    });
+    }));
 }
@@ -5,7 +5,7 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+    FP16_SWITCH(launch_params.params.is_bf16, ({
         if (launch_params.params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);
@@ -13,5 +13,5 @@ void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);
         }
-    });
+    }));
 }
@@ -8,7 +8,6 @@
 #include <cuda_bf16.h>
 
 #include "static_switch.h"
-#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
 
@@ -57,7 +56,7 @@ void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
     // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21
-    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
+    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ({
         auto kernel = launch_params.params.is_causal
             ? (launch_params.return_softmax
                 ? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
@@ -88,5 +87,5 @@ void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
         kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
             launch_params.params);
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
-    });
+    }));
 }
@@ -1,27 +0,0 @@
-// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
-// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
-
-// modified from static_switch.h
-// because MSVC cannot handle std::conditional with constexpr variable
-
-#pragma once
-
-/// @param COND - a boolean expression to switch by
-/// @param ... - code to execute for true and false
-///
-/// Usage:
-/// ```
-/// FP16_SWITCH(flag, [&] {
-///     some_function(...);
-/// });
-/// ```
-#define FP16_SWITCH(COND, ...)                 \
-    [&] {                                      \
-        if (COND) {                            \
-            using elem_type = __nv_bfloat16;   \
-            return __VA_ARGS__();              \
-        } else {                               \
-            using elem_type = __half;          \
-            return __VA_ARGS__();              \
-        }                                      \
-    }()
@@ -9,17 +9,27 @@
 ///
 /// Usage:
 /// ```
-/// BOOL_SWITCH(flag, BoolConst, [&] {
+/// BOOL_SWITCH(flag, BoolConst, ({
 ///     some_function<BoolConst>(...);
-/// });
+/// }));
 /// ```
-#define BOOL_SWITCH(COND, CONST_NAME, ...)     \
-    [&] {                                      \
+/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro.
+#define BOOL_SWITCH(COND, CONST_NAME, CODE)    \
     if (COND) {                                \
         constexpr bool CONST_NAME = true;      \
-        return __VA_ARGS__();                  \
+        CODE;                                  \
     } else {                                   \
         constexpr bool CONST_NAME = false;     \
-        return __VA_ARGS__();                  \
-    }                                          \
-    }()
+        CODE;                                  \
+    }
+
+// modified from BOOL_SWITCH
+// because MSVC cannot handle std::conditional with constexpr variable
+#define FP16_SWITCH(COND, CODE)                \
+    if (COND) {                                \
+        using elem_type = __nv_bfloat16;       \
+        CODE;                                  \
+    } else {                                   \
+        using elem_type = __half;              \
+        CODE;                                  \
+    }
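
For reference, a minimal sketch of why the new call sites must wrap their block in `({ ... })`, as the comment in static_switch.h notes: bare braces do not protect commas during macro-argument splitting, but parentheses do. This is illustrative only; the `launch` template is a hypothetical stand-in for the kernel launchers above, and the `({ ... });` expansion uses the GNU statement-expression extension accepted by gcc, clang, and nvcc.

```
#include <cstdio>

// Same shape as the new BOOL_SWITCH in static_switch.h.
#define BOOL_SWITCH(COND, CONST_NAME, CODE)    \
    if (COND) {                                \
        constexpr bool CONST_NAME = true;      \
        CODE;                                  \
    } else {                                   \
        constexpr bool CONST_NAME = false;     \
        CODE;                                  \
    }

template <int kRows, bool kIsDropout>
void launch() { std::printf("rows=%d dropout=%d\n", kRows, kIsDropout); }

int main() {
    bool is_dropout = true;

    // OK: ({ ... }) makes the whole block one macro argument; the comma in
    // launch<128, IsDropout> is hidden inside the parentheses.
    BOOL_SWITCH(is_dropout, IsDropout, ({
        launch<128, IsDropout>();
    }));

    // Would NOT compile: with bare braces the preprocessor splits the block
    // at the comma, passing the macro four arguments instead of three:
    // BOOL_SWITCH(is_dropout, IsDropout, { launch<128, IsDropout>(); });
    return 0;
}
```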