Patch summary (free text before the first "diff --git" header):

  * include/cutlass/epilogue/thread/activation.h
      - Generic GELU<T>::operator(): the erff() argument is now converted to
        float explicitly and the result is cast back to T before it is added
        to one<T>().
      - New explicit specialization GELU<double>, which calls erf() instead
        of erff().
  * include/cutlass/epilogue/thread/linear_combination_gelu.h
      - set_k_partition() takes a second parameter, k_partition_count, which
        is passed to CUTLASS_UNUSED and otherwise not read.
  * test/unit/epilogue/thread/linear_combination.cu
      - Includes linear_combination_gelu.h and activation.h for the new
        LinearCombinationGELU unit test.

diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index 49a63335..bcfed6ca 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -139,7 +139,7 @@ struct GELU {
   CUTLASS_HOST_DEVICE
   T operator()(T const &scalar) const {
     return T(cutlass::constants::half<T>() * scalar *
-      (cutlass::constants::one<T>() + erff( scalar / cutlass::constants::root_two<T>() )));
+      (cutlass::constants::one<T>() + (T)erff((float)(scalar / cutlass::constants::root_two<T>()))));
   }
 };
 
@@ -152,6 +152,15 @@ struct GELU {
   }
 };
 
+template <>
+struct GELU<double> {
+  CUTLASS_HOST_DEVICE
+  double operator()(double const &scalar) const {
+    return cutlass::constants::half<double>() * scalar *
+      (cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
+  }
+};
+
 template <typename T, int N>
 struct GELU<Array<T, N> > {
   CUTLASS_HOST_DEVICE
diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h
index c47e89f1..ebb08056 100644
--- a/include/cutlass/epilogue/thread/linear_combination_gelu.h
+++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h
@@ -133,7 +133,8 @@ public:
 
   /// Functionally required for serial reduction in the epilogue
   CUTLASS_HOST_DEVICE
-  void set_k_partition(int k_partition) {
+  void set_k_partition(int k_partition, int k_partition_count) {
+    CUTLASS_UNUSED(k_partition_count);
     if (k_partition) {
       beta_ = ElementCompute(1);
     }
diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu
index 5ff188a3..48275ea2 100644
--- a/test/unit/epilogue/thread/linear_combination.cu
+++ b/test/unit/epilogue/thread/linear_combination.cu
@@ -29,6 +29,8 @@
 #include "../../common/cutlass_unit_test.h"
 
 #include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/epilogue/thread/linear_combination_gelu.h"
+#include "cutlass/epilogue/thread/activation.h"
 
///////////////////////////////////////////////////////////////////////////////////////////////// @@ -119,3 +121,40 @@ TEST(Epilogue_thread_linear_combination, device_side_f16_f32_ptr) { } ///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { + + using Element = cutlass::half_t; + using ElementOutput = cutlass::half_t; + int const kCount = 8; + + using LinearCombination = cutlass::epilogue::thread::LinearCombinationGELU< + ElementOutput, + kCount, + Element, + Element>; + + Element alpha = Element(1); + Element beta = Element(0); + + typename LinearCombination::Params params(&alpha, &beta); + + LinearCombination linear_combination_op(params); + + cutlass::Array accum; + + for (int i = 0; i < kCount; ++i) { + accum[i] = Element((float)i * 0.3f); + } + + cutlass::Array destination = linear_combination_op(accum, accum); + cutlass::epilogue::thread::GELU gelu_func; + + for (int i = 0; i < kCount; ++i) { + ElementOutput expected = gelu_func(accum[i]); + ElementOutput got = destination[i]; + EXPECT_TRUE(expected == got); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file