[Gen, FT] Use fp32 accum for FMA

This commit is contained in:
Tri Dao 2023-01-03 22:09:22 -08:00
parent f266fc7262
commit be1afaa276

View File

@ -30,7 +30,7 @@
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
// Does not seem to affect the accuracy that much
// #define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_FP32_ACUM_FOR_FMA
// Seems to slightly improve the accuracy
#define MMHA_USE_FP32_ACUM_FOR_OUT
@ -271,27 +271,6 @@ struct Qk_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
template<>
struct Qk_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
using Type = float;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct Qk_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>