From 41a31b404b92c1b8ee2467c84208d45008c3d69b Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Sat, 17 Apr 2021 22:10:19 +0000 Subject: [PATCH 1/5] Fixes to Gelu for half and fusion --- include/cutlass/epilogue/thread/activation.h | 11 ++++++++++- .../cutlass/epilogue/thread/linear_combination_gelu.h | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 49a63335..65161905 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -139,7 +139,7 @@ struct GELU { CUTLASS_HOST_DEVICE T operator()(T const &scalar) const { return T(cutlass::constants::half() * scalar * - (cutlass::constants::one() + erff( scalar / cutlass::constants::root_two() ))); + (cutlass::constants::one() + (T)erff((float)(scalar / cutlass::constants::root_two())))); } }; @@ -152,6 +152,15 @@ struct GELU { } }; +template <> +struct GELU { + CUTLASS_HOST_DEVICE + float operator()(double const &scalar) const { + return cutlass::constants::half() * scalar * + (cutlass::constants::one() + erf( scalar / cutlass::constants::root_two() )); + } +}; + template struct GELU > { CUTLASS_HOST_DEVICE diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h index c47e89f1..baf3ebec 100644 --- a/include/cutlass/epilogue/thread/linear_combination_gelu.h +++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h @@ -133,7 +133,7 @@ public: /// Functionally required for serial reduction in the epilogue CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition) { + void set_k_partition(int k_partition, int k_partition_count) { if (k_partition) { beta_ = ElementCompute(1); } From 5c62d892faac6ae2131e4d222b4d8569c86dd0c4 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Sun, 18 Apr 2021 04:09:34 +0000 Subject: [PATCH 2/5] Add test --- .../epilogue/thread/linear_combination.cu | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index 5ff188a3..86587fa2 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -29,6 +29,8 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/activation.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -119,3 +121,47 @@ TEST(Epilogue_thread_linear_combination, device_side_f16_f32_ptr) { } ///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { + + using Element = cutlass::half_t; + using ElementOutput = cutlass::half_t; + int const kCount = 8; + + using LinearCombination = cutlass::epilogue::thread::LinearCombinationGELU< + ElementOutput, + kCount, + Element, + Element>; + + Element alpha = Element(1); + Element beta = Element(0); + + typename LinearCombination::Params params(&alpha, &beta); + + LinearCombination linear_combination_op(params); + + cutlass::Array source; + cutlass::Array accum; + + for (int i = 0; i < kCount; ++i) { + accum[i] = Element((float)i * 0.3f); + source[i] = ElementOutput(0); + } + + cutlass::Array destination = linear_combination_op(accum, source); + + const float sqrt2 = sqrtf(2.0f); + for (int i = 0; i < kCount; ++i) { + float scalar = (float)accum[i]; + ElementOutput expected = ElementOutput( + 0.5f * scalar * (1.0f + erff(scalar / sqrt2)) + ); + + ElementOutput got = destination[i]; + ElementOutput diff(fabs((float)(expected - got))); + EXPECT_TRUE(diff <= std::numeric_limits::epsilon()); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file From b7e43f5eb93e1ee3372c0f0a64969d190d557af1 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Sun, 18 Apr 2021 04:24:25 +0000 Subject: [PATCH 3/5] Clean up --- test/unit/epilogue/thread/linear_combination.cu | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index 86587fa2..bfee5405 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -141,23 +141,18 @@ TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { LinearCombination linear_combination_op(params); - cutlass::Array source; cutlass::Array accum; for (int i = 0; i < kCount; ++i) { accum[i] = Element((float)i * 0.3f); - source[i] = ElementOutput(0); } - cutlass::Array destination = linear_combination_op(accum, source); + cutlass::Array destination = linear_combination_op(accum, accum); const float sqrt2 = sqrtf(2.0f); + cutlass::epilogue::thread::GELU gelu_func; for (int i = 0; i < kCount; ++i) { - float scalar = (float)accum[i]; - ElementOutput expected = ElementOutput( - 0.5f * scalar * (1.0f + erff(scalar / sqrt2)) - ); - + ElementOutput expected = gelu_func(accum[i]); ElementOutput got = destination[i]; ElementOutput diff(fabs((float)(expected - got))); EXPECT_TRUE(diff <= std::numeric_limits::epsilon()); From 83036ed64668500b137c5315ebb76c14fb2fc737 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Sun, 18 Apr 2021 04:29:20 +0000 Subject: [PATCH 4/5] More clean up --- include/cutlass/epilogue/thread/activation.h | 2 +- test/unit/epilogue/thread/linear_combination.cu | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 65161905..bcfed6ca 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -155,7 +155,7 @@ struct GELU { template <> struct GELU { CUTLASS_HOST_DEVICE - float operator()(double const &scalar) const { + double operator()(double const &scalar) const { return cutlass::constants::half() * scalar * (cutlass::constants::one() + erf( scalar / cutlass::constants::root_two() )); } diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index bfee5405..48275ea2 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -148,14 +148,12 @@ TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { } cutlass::Array destination = linear_combination_op(accum, accum); - - const float sqrt2 = sqrtf(2.0f); cutlass::epilogue::thread::GELU gelu_func; + for (int i = 0; i < kCount; ++i) { ElementOutput expected = gelu_func(accum[i]); ElementOutput got = destination[i]; - ElementOutput diff(fabs((float)(expected - got))); - EXPECT_TRUE(diff <= std::numeric_limits::epsilon()); + EXPECT_TRUE(expected == got); } } From 0b74c8f473e04871adef767c6d2bb1b8ee8f213b Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Mon, 19 Apr 2021 23:36:06 +0000 Subject: [PATCH 5/5] Address CR --- include/cutlass/epilogue/thread/linear_combination_gelu.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h index baf3ebec..ebb08056 100644 --- a/include/cutlass/epilogue/thread/linear_combination_gelu.h +++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h @@ -134,6 +134,7 @@ public: /// Functionally required for serial reduction in the epilogue CUTLASS_HOST_DEVICE void set_k_partition(int k_partition, int k_partition_count) { + CUTLASS_UNUSED(k_partition_count); if (k_partition) { beta_ = ElementCompute(1); }