Merge pull request #52 from bob80333/main
Make flash attention compile on Windows.
Commit 88dc2040a0
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu

@@ -2,6 +2,7 @@
 */
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"
 
@@ -52,8 +53,8 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream)
 }
 
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
-    BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    // work around for MSVC issue
+    FP16_SWITCH(params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.d == 16) {
             if( params.seqlen_k == 128 ) {
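The two removed lines above are the heart of the Windows fix: BOOL_SWITCH materializes a runtime flag as a constexpr local and invokes a lambda, and MSVC rejects that constexpr local when it is used as a template argument to std::conditional inside the nested lambda (the new header's own comment: "MSVC cannot handle std::conditional with constexpr variable"). Below is a minimal, host-only sketch of the failing shape, not code from this PR: float/double stand in for __half/__nv_bfloat16, and BOOL_SWITCH is written in the style of static_switch.h.

    // Hypothetical repro of the MSVC issue this PR works around.
    #include <type_traits>

    #define BOOL_SWITCH(COND, CONST_NAME, ...)     \
        [&] {                                      \
            if (COND) {                            \
                constexpr bool CONST_NAME = true;  \
                return __VA_ARGS__();              \
            } else {                               \
                constexpr bool CONST_NAME = false; \
                return __VA_ARGS__();              \
            }                                      \
        }()

    void dispatch(bool is_bf16) {
        BOOL_SWITCH(is_bf16, IsBf16Const, [&] {
            // GCC/Clang accept this; MSVC rejects the constexpr local as a
            // template argument. The PR avoids the pattern by hard-coding
            // the type in each branch of a dedicated macro instead.
            using elem_type = std::conditional<IsBf16Const, double, float>::type;
            static_cast<void>(sizeof(elem_type));
        });
    }

The same substitution is applied to the forward-pass launcher in the next file.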
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

@@ -29,6 +29,7 @@
 #include <cuda_bf16.h>
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
 
@@ -83,8 +84,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
 
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                         const bool configure) {
-    BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if( launch_params.params.seqlen_k == 128 ) {
csrc/flash_attn/src/fp16_switch.h (new file, 27 lines)

@@ -0,0 +1,27 @@
+// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+
+// modified from static_switch.h
+// because MSVC cannot handle std::conditional with constexpr variable
+
+#pragma once
+
+/// @param COND - a boolean expression to switch by
+/// @param ... - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// FP16_SWITCH(flag, [&] {
+///     some_function(...);
+/// });
+/// ```
+#define FP16_SWITCH(COND, ...) \
+    [&] { \
+        if (COND) { \
+            using elem_type = __nv_bfloat16; \
+            return __VA_ARGS__(); \
+        } else { \
+            using elem_type = __half; \
+            return __VA_ARGS__(); \
+        } \
+    }()
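Note how the macro supplies elem_type without std::conditional: __VA_ARGS__ is expanded textually inside each if/else branch, after the corresponding `using` declaration, so a lambda written at the call site can simply name elem_type and is compiled once per branch with a different alias in scope. A self-contained sketch of how the launchers above consume it; Params and run_kernel are hypothetical stand-ins for the real parameter structs and kernel launches:

    #include <cuda_fp16.h>
    #include <cuda_bf16.h>
    #include "fp16_switch.h"

    struct Params { bool is_bf16; };

    template <typename elem_type>
    void run_kernel(const Params &) { /* launch a kernel specialized on elem_type */ }

    void dispatch(const Params &params) {
        FP16_SWITCH(params.is_bf16, [&] {
            // elem_type is the alias introduced by the branch that invoked
            // this lambda: __nv_bfloat16 when is_bf16 is true, __half otherwise.
            run_kernel<elem_type>(params);
        });
    }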
setup.py (3 lines changed)

@@ -125,10 +125,11 @@ ext_modules.append(
             "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
         ],
         extra_compile_args={
-            "cxx": ["-O3"] + generator_flag,
+            "cxx": ["-O3", "-std=c++17"] + generator_flag,
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
+                    "-std=c++17",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",
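For orientation, this hunk sits inside the ext_modules.append(CUDAExtension(...)) call named in the hunk header; after the change both the host ("cxx") and device ("nvcc") flag lists pin the standard to C++17. A hedged sketch of the surrounding call, abridged; the extension name, the elided source entries, and the bodies of generator_flag and append_nvcc_threads come from parts of setup.py not shown in this diff:

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_cuda",   # assumed extension name
            sources=[
                # ... other .cpp/.cu sources ...
                "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        # ... remaining nvcc flags ...
                    ]
                ),
            },
        )
    )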