diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
index 35c1a5b..a58d601 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
+++ b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
@@ -30,7 +30,7 @@
 // Below are knobs to extend FP32 accumulation for higher FP16 accuracy
 
 // Does not seem to affect the accuracy that much
-// #define MMHA_USE_FP32_ACUM_FOR_FMA
+#define MMHA_USE_FP32_ACUM_FOR_FMA
 
 // Seems to slightly improve the accuracy
 #define MMHA_USE_FP32_ACUM_FOR_OUT
@@ -271,27 +271,6 @@ struct Qk_vec_acum_fp32_<uint4> {
     using Type = Float8_;
 };
-template<>
-struct Qk_vec_acum_fp32_<uint4> {
-    using Type = Float8_;
-};
-template<>
-struct Qk_vec_acum_fp32_<__nv_bfloat16> {
-    using Type = float;
-};
-template<>
-struct Qk_vec_acum_fp32_<__nv_bfloat162> {
-    using Type = float2;
-};
-template<>
-struct Qk_vec_acum_fp32_<bf16_4_t> {
-    using Type = Float4_;
-};
-template<>
-struct Qk_vec_acum_fp32_<bf16_8_t> {
-    using Type = Float8_;
-};
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T>
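
Context for reviewers (not part of the patch): `Qk_vec_acum_fp32_` is a type trait that maps a low-precision vector type to an FP32 accumulator type, so the Q·K dot product can accumulate in FP32 even when the inputs are FP16/BF16. Enabling `MMHA_USE_FP32_ACUM_FOR_FMA` makes this trait block live, at which point duplicate explicit specializations become hard redefinition errors, which is why the second copy is removed above. A minimal standalone sketch of the same pattern, using simplified stand-in types (`half2_t`, `float2_t`) rather than the real CUDA vector types (`uint4`, `Float8_`, etc.):

```cpp
#include <cstdio>

// Stand-ins for the CUDA vector types used in the real header (illustrative only).
struct half2_t  { unsigned short x, y; };  // packed fp16x2
struct float2_t { float x, y; };           // fp32 accumulator for fp16x2

// Primary template left empty: unsupported input types fail to compile.
template<typename T>
struct Qk_vec_acum_fp32_ {};

// One explicit specialization per input vector type -> FP32 accumulator type.
// Defining the same specialization twice is a redefinition error, which is
// what the duplicated block in the patched header would have triggered once
// MMHA_USE_FP32_ACUM_FOR_FMA was defined.
template<>
struct Qk_vec_acum_fp32_<half2_t> {
    using Type = float2_t;
};

int main() {
    // A kernel would declare its accumulator via the trait like this:
    using Acum = Qk_vec_acum_fp32_<half2_t>::Type;
    Acum acc{0.f, 0.f};
    std::printf("accumulating in a %zu-byte fp32 vector\n", sizeof(acc));
    return 0;
}
```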