From 23e8fa5a263d1c7122bc46a86ef32030ee7130f9 Mon Sep 17 00:00:00 2001
From: Driss Guessous <32754868+drisspg@users.noreply.github.com>
Date: Wed, 27 Mar 2024 19:12:11 -0700
Subject: [PATCH] Add the option for the macro and note (#893)

---
 csrc/flash_attn/src/softmax.h | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h
index 189f2e2..ebf1b09 100644
--- a/csrc/flash_attn/src/softmax.h
+++ b/csrc/flash_attn/src/softmax.h
@@ -78,7 +78,14 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor &tenso
             // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
             // max * log_2(e)) This allows the compiler to use the ffma
             // instruction instead of fadd and fmul separately.
-            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+            // The following macro will disable the use of fma.
+            // See: https://github.com/pytorch/pytorch/issues/121558 for more details
+            // This macro is set in PyTorch and not FlashAttention
+            #ifdef UNFUSE_FMA
+                tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
+            #else
+                tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+            #endif
        }
    }
}
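
For context, here is a minimal standalone sketch (not part of the patch) showing the two code paths the `UNFUSE_FMA` guard selects between: the default expression that the compiler may contract into an `ffma`, and the `__fmul_rn` form whose multiply is rounded separately and cannot be fused. The kernel name, test values, and host driver are illustrative assumptions, not FlashAttention code; only `exp2f`, `__fmul_rn`, and the `scale`/`max_scaled` pattern come from the patch above.

```cuda
// Sketch only: compares the fused and unfused multiply paths from the patch.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void softmax_exp2_demo(const float* x, float scale, float max_scaled,
                                  float* out_fused, float* out_unfused) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    // Default path: x[i] * scale - max_scaled may be contracted into a single
    // ffma, which keeps the intermediate product at higher precision.
    out_fused[i] = exp2f(x[i] * scale - max_scaled);
    // UNFUSE_FMA path: __fmul_rn is an explicitly rounded multiply that the
    // compiler is not allowed to fuse with the following subtraction.
    out_unfused[i] = exp2f(__fmul_rn(x[i], scale) - max_scaled);
}

int main() {
    const int n = 4;
    float h_x[n] = {1.0f, -2.5f, 3.14159f, 100.0f};
    float *d_x, *d_fused, *d_unfused;
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMalloc(&d_fused, n * sizeof(float));
    cudaMalloc(&d_unfused, n * sizeof(float));
    cudaMemcpy(d_x, h_x, n * sizeof(float), cudaMemcpyHostToDevice);

    // scale = log2(e) mirrors the exp(x) -> exp2(x * log_2(e)) rewrite noted
    // in the patched comment; max_scaled = 0 keeps the demo simple.
    softmax_exp2_demo<<<1, n>>>(d_x, 1.4426950408889634f, 0.0f, d_fused, d_unfused);

    float h_fused[n], h_unfused[n];
    cudaMemcpy(h_fused, d_fused, n * sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_unfused, d_unfused, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("x=%g fused=%.9g unfused=%.9g\n", h_x[i], h_fused[i], h_unfused[i]);
    }
    cudaFree(d_x); cudaFree(d_fused); cudaFree(d_unfused);
    return 0;
}
```

Per the patch note, FlashAttention itself never defines `UNFUSE_FMA`; a consumer such as PyTorch opts in by passing the define (e.g. `-DUNFUSE_FMA`) when compiling the kernels, which switches the softmax to the unfused multiply shown above.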