[Gen, FT] Use fp32 accum for FMA
This commit is contained in:
parent
f266fc7262
commit
be1afaa276
@ -30,7 +30,7 @@
|
||||
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
|
||||
|
||||
// Does not seem to affect the accuracy that much
|
||||
// #define MMHA_USE_FP32_ACUM_FOR_FMA
|
||||
#define MMHA_USE_FP32_ACUM_FOR_FMA
|
||||
|
||||
// Seems to slightly improve the accuracy
|
||||
#define MMHA_USE_FP32_ACUM_FOR_OUT
|
||||
@ -271,27 +271,6 @@ struct Qk_vec_acum_fp32_<bf16_8_t> {
|
||||
using Type = Float8_;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Qk_vec_acum_fp32_<uint4> {
|
||||
using Type = Float8_;
|
||||
};
|
||||
template<>
|
||||
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
|
||||
using Type = float;
|
||||
};
|
||||
template<>
|
||||
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
|
||||
using Type = float2;
|
||||
};
|
||||
template<>
|
||||
struct Qk_vec_acum_fp32_<bf16_4_t> {
|
||||
using Type = Float4_;
|
||||
};
|
||||
template<>
|
||||
struct Qk_vec_acum_fp32_<bf16_8_t> {
|
||||
using Type = Float8_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
|
||||
Loading…
Reference in New Issue
Block a user