From 83036ed64668500b137c5315ebb76c14fb2fc737 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Sun, 18 Apr 2021 04:29:20 +0000 Subject: [PATCH] More clean up --- include/cutlass/epilogue/thread/activation.h | 2 +- test/unit/epilogue/thread/linear_combination.cu | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 65161905..bcfed6ca 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -155,7 +155,7 @@ struct GELU { template <> struct GELU { CUTLASS_HOST_DEVICE - float operator()(double const &scalar) const { + double operator()(double const &scalar) const { return cutlass::constants::half() * scalar * (cutlass::constants::one() + erf( scalar / cutlass::constants::root_two() )); } diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index bfee5405..48275ea2 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -148,14 +148,12 @@ TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { } cutlass::Array destination = linear_combination_op(accum, accum); - - const float sqrt2 = sqrtf(2.0f); cutlass::epilogue::thread::GELU gelu_func; + for (int i = 0; i < kCount; ++i) { ElementOutput expected = gelu_func(accum[i]); ElementOutput got = destination[i]; - ElementOutput diff(fabs((float)(expected - got))); - EXPECT_TRUE(diff <= std::numeric_limits::epsilon()); + EXPECT_TRUE(expected == got); } }