Merge pull request #52 from bob80333/main
Make flash attention compile on Windows.
commit 88dc2040a0
@@ -2,6 +2,7 @@
  */
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"
@@ -52,8 +53,8 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
 }
 
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
-    BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    // work around for MSVC issue
+    FP16_SWITCH(params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.d == 16) {
             if( params.seqlen_k == 128 ) {
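For context on the "// work around for MSVC issue" comment: the removed lines key std::conditional on the constexpr bool that BOOL_SWITCH introduces inside its lambda branches, and this PR reports that MSVC cannot compile that combination, although the toolchains the project already supported accept it. The replacement, FP16_SWITCH (added in csrc/flash_attn/src/fp16_switch.h below), defines elem_type in each branch itself, so the callback never needs std::conditional. The following is a minimal, self-contained sketch of the pattern being removed; BOOL_SWITCH is reproduced only roughly from static_switch.h, and launch_kernel/dispatch are hypothetical stand-ins, not code from this PR.

// Simplified illustration of the dispatch pattern this PR removes
// (hypothetical launcher; not the real kernel launch code).
#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Roughly the DALI-style BOOL_SWITCH from static_switch.h: the runtime flag
// becomes a lambda-scoped constexpr bool that the callback can branch on.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()

template <typename elem_type>
void launch_kernel() { /* stand-in for the real templated kernel launch */ }

void dispatch(bool is_bf16) {
    // The combination this PR works around: std::conditional keyed on the
    // constexpr bool introduced by BOOL_SWITCH. Per this PR, MSVC rejects it,
    // hence the move to FP16_SWITCH, which names elem_type in each branch.
    BOOL_SWITCH(is_bf16, IsBf16Const, [&] {
        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
        launch_kernel<elem_type>();
    });
}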
@@ -29,6 +29,7 @@
 #include <cuda_bf16.h>
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
@@ -83,8 +84,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
 
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                         const bool configure) {
-    BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if( launch_params.params.seqlen_k == 128 ) {
csrc/flash_attn/src/fp16_switch.h (new file, 27 lines)
@@ -0,0 +1,27 @@
+// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+
+// modified from static_switch.h
+// because MSVC cannot handle std::conditional with constexpr variable
+
+#pragma once
+
+/// @param COND - a boolean expression to switch by
+/// @param ... - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// FP16_SWITCH(flag, [&] {
+///     some_function(...);
+/// });
+/// ```
+#define FP16_SWITCH(COND, ...)               \
+    [&] {                                    \
+        if (COND) {                          \
+            using elem_type = __nv_bfloat16; \
+            return __VA_ARGS__();            \
+        } else {                             \
+            using elem_type = __half;        \
+            return __VA_ARGS__();            \
+        }                                    \
+    }()
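A brief note on how the macro makes elem_type available to its callback: __VA_ARGS__() is expanded textually inside each if/else branch, so the alias the branch defines is found by ordinary name lookup from the callback's body (type aliases are never captured). The sketch below is a hedged usage example, not code from this PR; run_kernel and dispatch are hypothetical stand-ins for the run_fmha_*_sm80 launchers above.

// Hypothetical usage sketch for FP16_SWITCH (not part of the PR).
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "fp16_switch.h"

template <typename elem_type>
void run_kernel() { /* stand-in for a templated kernel launcher */ }

void dispatch(bool is_bf16) {
    FP16_SWITCH(is_bf16, [&] {
        // elem_type is __nv_bfloat16 when is_bf16 is true, __half otherwise;
        // it comes from the macro branch this callback is expanded into.
        run_kernel<elem_type>();
    });
}

// The call above expands to roughly:
//
//   [&] {
//       if (is_bf16) {
//           using elem_type = __nv_bfloat16;
//           return [&] { run_kernel<elem_type>(); }();
//       } else {
//           using elem_type = __half;
//           return [&] { run_kernel<elem_type>(); }();
//       }
//   }();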
setup.py (3 lines changed)
@@ -125,10 +125,11 @@ ext_modules.append(
             "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
         ],
         extra_compile_args={
-            "cxx": ["-O3"] + generator_flag,
+            "cxx": ["-O3", "-std=c++17"] + generator_flag,
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
+                    "-std=c++17",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",