Fixes to Gelu for half and fusion

This commit is contained in:
KeDengMS 2021-04-17 22:10:19 +00:00
parent c77a524459
commit 41a31b404b
2 changed files with 11 additions and 2 deletions

View File

@ -139,7 +139,7 @@ struct GELU {
CUTLASS_HOST_DEVICE
T operator()(T const &scalar) const {
return T(cutlass::constants::half<T>() * scalar *
(cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
}
};
@ -152,6 +152,15 @@ struct GELU<float> {
}
};
/// GELU functor specialized for double-precision operands.
///
/// Evaluates half * x * (one + erf(x / root_two)) using the double-typed
/// constants from cutlass::constants. Because the argument passed to erf is a
/// double, the double-precision overload is selected rather than erff.
template <>
struct GELU<double> {
  /// Applies the activation to a single value.
  ///
  /// \param scalar input value
  /// \return the activation result as a double; the whole expression is
  ///         evaluated in double, so returning float would silently narrow it
  ///         to single precision.
  CUTLASS_HOST_DEVICE
  double operator()(double const &scalar) const {
    return cutlass::constants::half<double>() * scalar *
      (cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
  }
};
template <typename T, int N>
struct GELU<Array<T, N> > {
CUTLASS_HOST_DEVICE

View File

@ -133,7 +133,7 @@ public:
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition) {
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}