From be1afaa2765d3fc57fbf7a6810cd3b146cfdd0ee Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 3 Jan 2023 22:09:22 -0800
Subject: [PATCH] [Gen, FT] Use fp32 accum for FMA

---
 ...er_masked_multihead_attention_template.hpp | 23 +------------------
 1 file changed, 1 insertion(+), 22 deletions(-)

diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
index 35c1a5b..a58d601 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
+++ b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
@@ -30,7 +30,7 @@
 // Below are knobs to extend FP32 accumulation for higher FP16 accuracy
 
 // Does not seem to affect the accuracy that much
-// #define MMHA_USE_FP32_ACUM_FOR_FMA
+#define MMHA_USE_FP32_ACUM_FOR_FMA
 
 // Seems to slightly improve the accuracy
 #define MMHA_USE_FP32_ACUM_FOR_OUT
@@ -271,27 +271,6 @@ struct Qk_vec_acum_fp32_<bf16_8_t> {
     using Type = Float8_;
 };
 
-template<>
-struct Qk_vec_acum_fp32_<uint4> {
-    using Type = Float8_;
-};
-template<>
-struct Qk_vec_acum_fp32_<__nv_bfloat16> {
-    using Type = float;
-};
-template<>
-struct Qk_vec_acum_fp32_<__nv_bfloat162> {
-    using Type = float2;
-};
-template<>
-struct Qk_vec_acum_fp32_<bf16_4_t> {
-    using Type = Float4_;
-};
-template<>
-struct Qk_vec_acum_fp32_<bf16_8_t> {
-    using Type = Float8_;
-};
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T>
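
For context on what the enabled knob does (not stated in the patch itself): with MMHA_USE_FP32_ACUM_FOR_FMA defined, the Q*K dot product accumulates in fp32 even when Q and K are stored as fp16/bf16, and the Qk_vec_acum_fp32_ traits above map each packed vector type to its fp32 accumulator type (e.g. eight halves in a uint4 accumulate into a Float8_). The second hunk also deletes a duplicated run of those specializations, presumably because enabling the define would otherwise expose them as redefinition errors. Below is a minimal CUDA sketch of the accuracy difference; the helper names qk_dot_fp16_acum / qk_dot_fp32_acum are hypothetical illustrations, not code from the patch or from FasterTransformer.

#include <cuda_fp16.h>

// Hypothetical sketch, not from the patch.
// fp16 accumulation: every partial sum is rounded back to half,
// so rounding error grows with the head dimension.
// (__hfma on __half requires sm_53 or newer.)
__device__ __half qk_dot_fp16_acum(const __half* q, const __half* k, int dim) {
    __half acc = __float2half(0.f);
    for (int i = 0; i < dim; ++i) {
        acc = __hfma(q[i], k[i], acc);  // fused multiply-add, rounded to half
    }
    return acc;
}

// fp32 accumulation: the operands stay fp16 (same memory traffic), but the
// running sum keeps full float precision, which is what the enabled knob buys.
__device__ float qk_dot_fp32_acum(const __half* q, const __half* k, int dim) {
    float acc = 0.f;
    for (int i = 0; i < dim; ++i) {
        acc = fmaf(__half2float(q[i]), __half2float(k[i]), acc);
    }
    return acc;
}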