replace division with multiplication in GELU (#942)

This commit is contained in:
wll 2023-05-12 22:57:18 +08:00 committed by GitHub
parent fcfbd23e26
commit 19c4a4815e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -515,7 +515,7 @@ struct GELU {
CUTLASS_HOST_DEVICE
T operator()(T const &scalar) const {
return T(cutlass::constants::half<T>() * scalar *
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
(cutlass::constants::one<T>() + (T)erff((float)(scalar * cutlass::constants::half_root_two<T>()))));
}
using Params = LinearCombinationGenericParams<T>;
@ -531,7 +531,7 @@ struct GELU<float> {
CUTLASS_HOST_DEVICE
float operator()(float const &scalar) const {
return cutlass::constants::half<float>() * scalar *
(cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
(cutlass::constants::one<float>() + erff( scalar * cutlass::constants::half_root_two<float>() ));
}
using Params = LinearCombinationGenericParams<float>;
@ -547,7 +547,7 @@ struct GELU<double> {
CUTLASS_HOST_DEVICE
double operator()(double const &scalar) const {
return cutlass::constants::half<double>() * scalar *
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
(cutlass::constants::one<double>() + erf( scalar * cutlass::constants::half_root_two<double>() ));
}
using Params = LinearCombinationGenericParams<double>;