diff --git a/test/unit/epilogue/thread/linear_combination.cu b/test/unit/epilogue/thread/linear_combination.cu index 86587fa2..bfee5405 100644 --- a/test/unit/epilogue/thread/linear_combination.cu +++ b/test/unit/epilogue/thread/linear_combination.cu @@ -141,23 +141,18 @@ TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) { LinearCombination linear_combination_op(params); - cutlass::Array source; cutlass::Array accum; for (int i = 0; i < kCount; ++i) { accum[i] = Element((float)i * 0.3f); - source[i] = ElementOutput(0); } - cutlass::Array destination = linear_combination_op(accum, source); + cutlass::Array destination = linear_combination_op(accum, accum); const float sqrt2 = sqrtf(2.0f); + cutlass::epilogue::thread::GELU gelu_func; for (int i = 0; i < kCount; ++i) { - float scalar = (float)accum[i]; - ElementOutput expected = ElementOutput( - 0.5f * scalar * (1.0f + erff(scalar / sqrt2)) - ); - + ElementOutput expected = gelu_func(accum[i]); ElementOutput got = destination[i]; ElementOutput diff(fabs((float)(expected - got))); EXPECT_TRUE(diff <= std::numeric_limits::epsilon());