More clean up

This commit is contained in:
KeDengMS 2021-04-18 04:29:20 +00:00
parent b7e43f5eb9
commit 83036ed646
2 changed files with 3 additions and 5 deletions

View File

@ -155,7 +155,7 @@ struct GELU<float> {
template <> template <>
struct GELU<double> { struct GELU<double> {
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
float operator()(double const &scalar) const { double operator()(double const &scalar) const {
return cutlass::constants::half<double>() * scalar * return cutlass::constants::half<double>() * scalar *
(cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() )); (cutlass::constants::one<double>() + erf( scalar / cutlass::constants::root_two<double>() ));
} }

View File

@ -148,14 +148,12 @@ TEST(Epilogue_thread_linear_combination_gelu, device_side_f16_f16_ptr) {
} }
cutlass::Array<ElementOutput, kCount> destination = linear_combination_op(accum, accum); cutlass::Array<ElementOutput, kCount> destination = linear_combination_op(accum, accum);
const float sqrt2 = sqrtf(2.0f);
cutlass::epilogue::thread::GELU<ElementOutput> gelu_func; cutlass::epilogue::thread::GELU<ElementOutput> gelu_func;
for (int i = 0; i < kCount; ++i) { for (int i = 0; i < kCount; ++i) {
ElementOutput expected = gelu_func(accum[i]); ElementOutput expected = gelu_func(accum[i]);
ElementOutput got = destination[i]; ElementOutput got = destination[i];
ElementOutput diff(fabs((float)(expected - got))); EXPECT_TRUE(expected == got);
EXPECT_TRUE(diff <= std::numeric_limits<cutlass::half_t>::epsilon());
} }
} }