replace division with multiplication in GELU (#942)
This commit is contained in:
parent
fcfbd23e26
commit
19c4a4815e
@ -515,7 +515,7 @@ struct GELU {
|
|||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &scalar) const {
|
T operator()(T const &scalar) const {
|
||||||
return T(cutlass::constants::half<T>() * scalar *
|
return T(cutlass::constants::half<T>() * scalar *
|
||||||
(cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
|
(cutlass::constants::one<T>() + (T)erff((float)(scalar * cutlass::constants::half_root_two<T>()))));
|
||||||
}
|
}
|
||||||
|
|
||||||
using Params = LinearCombinationGenericParams<T>;
|
using Params = LinearCombinationGenericParams<T>;
|
||||||
@ -531,7 +531,7 @@ struct GELU<float> {
|
|||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
float operator()(float const &scalar) const {
|
float operator()(float const &scalar) const {
|
||||||
return cutlass::constants::half<float>() * scalar *
|
return cutlass::constants::half<float>() * scalar *
|
||||||
(cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
|
(cutlass::constants::one<float>() + erff( scalar * cutlass::constants::half_root_two<float>() ));
|
||||||
}
|
}
|
||||||
|
|
||||||
using Params = LinearCombinationGenericParams<float>;
|
using Params = LinearCombinationGenericParams<float>;
|
||||||
@ -547,7 +547,7 @@ struct GELU<double> {
|
|||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
double operator()(double const &scalar) const {
|
double operator()(double const &scalar) const {
|
||||||
return cutlass::constants::half<double>() * scalar *
|
return cutlass::constants::half<double>() * scalar *
|
||||||
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
|
(cutlass::constants::one<double>() + erf( scalar * cutlass::constants::half_root_two<double>() ));
|
||||||
}
|
}
|
||||||
|
|
||||||
using Params = LinearCombinationGenericParams<double>;
|
using Params = LinearCombinationGenericParams<double>;
|
||||||
|
Loading…
Reference in New Issue
Block a user