diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index 49a63335..65161905 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -139,7 +139,7 @@ struct GELU {
   CUTLASS_HOST_DEVICE
   T operator()(T const &scalar) const {
     return T(cutlass::constants::half<T>() * scalar *
-      (cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
+      (cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
   }
 };
 
@@ -152,6 +152,17 @@ struct GELU<float> {
   }
 };
 
+/// GELU specialization for double: computes half() * x * (one() + erf(x / root_two()))
+/// with the double-precision erf() and returns the result as double (no narrowing to float).
+template <>
+struct GELU<double> {
+  CUTLASS_HOST_DEVICE
+  double operator()(double const &scalar) const {
+    return cutlass::constants::half<double>() * scalar *
+      (cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
+  }
+};
+
 template <typename T, int N>
 struct GELU<Array<T, N> > {
   CUTLASS_HOST_DEVICE
diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h
index c47e89f1..baf3ebec 100644
--- a/include/cutlass/epilogue/thread/linear_combination_gelu.h
+++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h
@@ -133,7 +133,7 @@ public:
 
   /// Functionally required for serial reduction in the epilogue
   CUTLASS_HOST_DEVICE
-  void set_k_partition(int k_partition) {
+  void set_k_partition(int k_partition, int k_partition_count) {
     if (k_partition) {
       beta_ = ElementCompute(1);
     }