[Gen, FT] Use fp32 accum for FMA
This commit is contained in:
parent
f266fc7262
commit
be1afaa276
@ -30,7 +30,7 @@
|
|||||||
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
|
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
|
||||||
|
|
||||||
// Does not seem to affect the accuracy that much
|
// Does not seem to affect the accuracy that much
|
||||||
// #define MMHA_USE_FP32_ACUM_FOR_FMA
|
#define MMHA_USE_FP32_ACUM_FOR_FMA
|
||||||
|
|
||||||
// Seems to slightly improve the accuracy
|
// Seems to slightly improve the accuracy
|
||||||
#define MMHA_USE_FP32_ACUM_FOR_OUT
|
#define MMHA_USE_FP32_ACUM_FOR_OUT
|
||||||
@ -271,27 +271,6 @@ struct Qk_vec_acum_fp32_<bf16_8_t> {
|
|||||||
using Type = Float8_;
|
using Type = Float8_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<>
|
|
||||||
struct Qk_vec_acum_fp32_<uint4> {
|
|
||||||
using Type = Float8_;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
|
|
||||||
using Type = float;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
|
|
||||||
using Type = float2;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Qk_vec_acum_fp32_<bf16_4_t> {
|
|
||||||
using Type = Float4_;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Qk_vec_acum_fp32_<bf16_8_t> {
|
|
||||||
using Type = Float8_;
|
|
||||||
};
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user