From 5c62d892faac6ae2131e4d222b4d8569c86dd0c4 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Sun, 18 Apr 2021 04:09:34 +0000 Subject: [PATCH] Add test --- .../epilogue/thread/linear_combination.cu | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index 5ff188a3..86587fa2 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -29,6 +29,8 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/activation.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -119,3 +121,47 @@ TEST(Epilogue_thread_linear_combination, device_side_f16_f32_ptr) { } ///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { + + using Element = cutlass::half_t; + using ElementOutput = cutlass::half_t; + int const kCount = 8; + + using LinearCombination = cutlass::epilogue::thread::LinearCombinationGELU< + ElementOutput, + kCount, + Element, + Element>; + + Element alpha = Element(1); + Element beta = Element(0); + + typename LinearCombination::Params params(&alpha, &beta); + + LinearCombination linear_combination_op(params); + + cutlass::Array source; + cutlass::Array accum; + + for (int i = 0; i < kCount; ++i) { + accum[i] = Element((float)i * 0.3f); + source[i] = ElementOutput(0); + } + + cutlass::Array destination = linear_combination_op(accum, source); + + const float sqrt2 = sqrtf(2.0f); + for (int i = 0; i < kCount; ++i) { + float scalar = (float)accum[i]; + ElementOutput expected = ElementOutput( + 0.5f * scalar * (1.0f + erff(scalar / sqrt2)) + ); + + ElementOutput got = destination[i]; + ElementOutput diff(fabs((float)(expected - got))); + EXPECT_TRUE(diff <= std::numeric_limits::epsilon()); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file