replace division with multiplication in GELU (#942)
This commit is contained in:
parent
fcfbd23e26
commit
19c4a4815e
@ -515,7 +515,7 @@ struct GELU {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &scalar) const {
|
||||
return T(cutlass::constants::half<T>() * scalar *
|
||||
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
|
||||
(cutlass::constants::one<T>() + (T)erff((float)(scalar * cutlass::constants::half_root_two<T>()))));
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<T>;
|
||||
@ -531,7 +531,7 @@ struct GELU<float> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
float operator()(float const &scalar) const {
|
||||
return cutlass::constants::half<float>() * scalar *
|
||||
(cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
|
||||
(cutlass::constants::one<float>() + erff( scalar * cutlass::constants::half_root_two<float>() ));
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<float>;
|
||||
@ -547,7 +547,7 @@ struct GELU<double> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
double operator()(double const &scalar) const {
|
||||
return cutlass::constants::half<double>() * scalar *
|
||||
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
|
||||
(cutlass::constants::one<double>() + erf( scalar * cutlass::constants::half_root_two<double>() ));
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<double>;
|
||||
|
Loading…
Reference in New Issue
Block a user