From 41a31b404b92c1b8ee2467c84208d45008c3d69b Mon Sep 17 00:00:00 2001
From: KeDengMS
Date: Sat, 17 Apr 2021 22:10:19 +0000
Subject: [PATCH] Fixes to Gelu for half and fusion

---
Review notes (this section is ignored when the patch is applied):

* activation.h, generic GELU: the erff() argument is cast to float and the
  result is cast back to T, so the expression no longer relies on implicit
  conversions between T and float.
* activation.h, new specialization: adds a GELU overload for double that
  calls erf() instead of erff(). Its operator() returns double; an earlier
  revision of this patch declared the return type as float, which narrowed
  the double-precision result. Because of that correction the post-image
  blob id on the "index" line below no longer describes the patched file.
* linear_combination_gelu.h: set_k_partition() gains a second int parameter,
  k_partition_count, which is not used in the lines visible in this hunk.
* NOTE(review): the template parameter/argument lists (<T>, <double>,
  <typename T, int N>, <Array<T, N> >) and the line structure were
  reconstructed after a lossy export of this patch. Confirm context lines
  and whitespace against the target tree before applying.

 include/cutlass/epilogue/thread/activation.h          | 11 ++++++++++-
 .../cutlass/epilogue/thread/linear_combination_gelu.h |  2 +-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index 49a63335..65161905 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -139,7 +139,7 @@ struct GELU {
   CUTLASS_HOST_DEVICE
   T operator()(T const &scalar) const {
     return T(cutlass::constants::half<T>() * scalar *
-      (cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
+      (cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
   }
 };
 
@@ -152,6 +152,15 @@ struct GELU<float> {
   }
 };
 
+template <>
+struct GELU<double> {
+  CUTLASS_HOST_DEVICE
+  double operator()(double const &scalar) const {
+    return cutlass::constants::half<double>() * scalar *
+      (cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
+  }
+};
+
 template <typename T, int N>
 struct GELU<Array<T, N> > {
   CUTLASS_HOST_DEVICE
diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h
index c47e89f1..baf3ebec 100644
--- a/include/cutlass/epilogue/thread/linear_combination_gelu.h
+++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h
@@ -133,7 +133,7 @@ public:
 
   /// Functionally required for serial reduction in the epilogue
   CUTLASS_HOST_DEVICE
-  void set_k_partition(int k_partition) {
+  void set_k_partition(int k_partition, int k_partition_count) {
     if (k_partition) {
       beta_ = ElementCompute(1);
     }