From 19c4a4815e8de22df9fc6fbb728ab264b385144f Mon Sep 17 00:00:00 2001
From: wll
Date: Fri, 12 May 2023 22:57:18 +0800
Subject: [PATCH] replace division with multiplication in GELU (#942)

---
 include/cutlass/epilogue/thread/activation.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index 484f2ccd..79c6072c 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -515,7 +515,7 @@ struct GELU {
   CUTLASS_HOST_DEVICE
   T operator()(T const &scalar) const {
     return T(cutlass::constants::half<T>() * scalar *
-      (cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
+      (cutlass::constants::one<T>() + (T)erff((float)(scalar * cutlass::constants::half_root_two<T>()))));
   }
 
   using Params = LinearCombinationGenericParams<T>;
@@ -531,7 +531,7 @@ struct GELU<float> {
   CUTLASS_HOST_DEVICE
   float operator()(float const &scalar) const {
     return cutlass::constants::half<float>() * scalar *
-      (cutlass::constants::one<float>() + erff( scalar / cutlass::constants::root_two<float>() ));
+      (cutlass::constants::one<float>() + erff( scalar * cutlass::constants::half_root_two<float>() ));
   }
 
   using Params = LinearCombinationGenericParams<float>;
@@ -547,7 +547,7 @@ struct GELU<double> {
   CUTLASS_HOST_DEVICE
   double operator()(double const &scalar) const {
     return cutlass::constants::half<double>() * scalar *
-      (cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
+      (cutlass::constants::one<double>() + erf( scalar * cutlass::constants::half_root_two<double>() ));
   }
 
   using Params = LinearCombinationGenericParams<double>;