diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h
index 189f2e2..ebf1b09 100644
--- a/csrc/flash_attn/src/softmax.h
+++ b/csrc/flash_attn/src/softmax.h
@@ -78,7 +78,14 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor &tenso
             // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
             // max * log_2(e)) This allows the compiler to use the ffma
             // instruction instead of fadd and fmul separately.
-            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+            // The following macro disables the use of fma.
+            // See https://github.com/pytorch/pytorch/issues/121558 for more details.
+            // This macro is set by PyTorch, not by FlashAttention.
+            #ifdef UNFUSE_FMA
+            tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
+            #else
+            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+            #endif
         }
     }
 }
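
For reference, the sketch below (not part of the patch, with assumed names such as demo_unfuse_fma) illustrates what the UNFUSE_FMA path changes: by default the compiler may contract x[i] * scale - max_scaled into a single ffma, which rounds only once, whereas __fmul_rn performs a separately rounded multiply that is never merged into an FMA, so the unfused result matches a reference computed without FMA contraction.

// Standalone repro sketch (illustrative only); build with e.g.: nvcc -o unfuse_fma_demo unfuse_fma_demo.cu
#include <cstdio>
#include <cuda_runtime.h>

__global__ void demo_unfuse_fma(const float* x, float scale, float max_scaled,
                                float* fused, float* unfused) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    // Default form: eligible for ffma contraction (multiply and subtract rounded once).
    fused[i] = exp2f(x[i] * scale - max_scaled);
    // UNFUSE_FMA form: __fmul_rn is a separately rounded multiply that the
    // compiler will not merge into an FMA instruction.
    unfused[i] = exp2f(__fmul_rn(x[i], scale) - max_scaled);
}

int main() {
    const int n = 256;
    float *x, *fused, *unfused;
    cudaMallocManaged(&x, n * sizeof(float));
    cudaMallocManaged(&fused, n * sizeof(float));
    cudaMallocManaged(&unfused, n * sizeof(float));
    for (int i = 0; i < n; ++i) x[i] = 0.1f * i - 12.7f;
    // scale = log2(e), matching the exp2-based softmax trick in the patched code.
    demo_unfuse_fma<<<1, n>>>(x, 1.4426950408889634f, 0.f, fused, unfused);
    cudaDeviceSynchronize();
    int mismatches = 0;
    for (int i = 0; i < n; ++i) mismatches += (fused[i] != unfused[i]);
    // Whether any low bits differ depends on FMA contraction (nvcc --fmad=true/false).
    printf("elements differing between fused and unfused paths: %d\n", mismatches);
    cudaFree(x); cudaFree(fused); cudaFree(unfused);
    return 0;
}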