From b7e43f5eb93e1ee3372c0f0a64969d190d557af1 Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Sun, 18 Apr 2021 04:24:25 +0000 Subject: [PATCH] Clean up --- test/unit/epilogue/thread/linear_combination.cu | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index 86587fa2..bfee5405 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -141,23 +141,18 @@ TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { LinearCombination linear_combination_op(params); - cutlass::Array source; cutlass::Array accum; for (int i = 0; i < kCount; ++i) { accum[i] = Element((float)i * 0.3f); - source[i] = ElementOutput(0); } - cutlass::Array destination = linear_combination_op(accum, source); + cutlass::Array destination = linear_combination_op(accum, accum); const float sqrt2 = sqrtf(2.0f); + cutlass::epilogue::thread::GELU gelu_func; for (int i = 0; i < kCount; ++i) { - float scalar = (float)accum[i]; - ElementOutput expected = ElementOutput( - 0.5f * scalar * (1.0f + erff(scalar / sqrt2)) - ); - + ElementOutput expected = gelu_func(accum[i]); ElementOutput got = destination[i]; ElementOutput diff(fabs((float)(expected - got))); EXPECT_TRUE(diff <= std::numeric_limits::epsilon());