diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 484f2ccd..79c6072c 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -515,7 +515,7 @@ struct GELU { CUTLASS_HOST_DEVICE T operator()(T const &scalar) const { return T(cutlass::constants::half() * scalar * - (cutlass::constants::one() + (T)erff((float)(scalar / cutlass::constants::root_two())))); + (cutlass::constants::one() + (T)erff((float)(scalar * cutlass::constants::half_root_two())))); } using Params = LinearCombinationGenericParams; @@ -531,7 +531,7 @@ struct GELU { CUTLASS_HOST_DEVICE float operator()(float const &scalar) const { return cutlass::constants::half() * scalar * - (cutlass::constants::one() + erff( scalar / cutlass::constants::root_two() )); + (cutlass::constants::one() + erff( scalar * cutlass::constants::half_root_two() )); } using Params = LinearCombinationGenericParams; @@ -547,7 +547,7 @@ struct GELU { CUTLASS_HOST_DEVICE double operator()(double const &scalar) const { return cutlass::constants::half() * scalar * - (cutlass::constants::one() + erf( scalar / cutlass::constants::root_two() )); + (cutlass::constants::one() + erf( scalar * cutlass::constants::half_root_two() )); } using Params = LinearCombinationGenericParams;