diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index ce34be63..38597ccb 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -98,15 +98,7 @@ template <typename T>
 struct Sigmoid {
   CUTLASS_HOST_DEVICE
   T operator()(T const &scalar) const {
-    return T(1) / (T(1) + exp(-scalar));
-  }
-};
-
-template <>
-struct Sigmoid<float> {
-  CUTLASS_HOST_DEVICE
-  float operator()(float const &scalar) const {
-    return 1.0f / (1.0f + expf(-scalar));
+    return T(1) / (T(1) + fast_exp(-scalar));
   }
 };
 
@@ -126,6 +118,30 @@ struct Sigmoid<Array<T, N> > {
   }
 };
 
+template <int N>
+struct Sigmoid<Array<half_t, N>> {
+  using T = half_t;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const& z) const {
+    plus<Array<T, N>> add;
+
+#if defined(CUTLASS_USE_TANH_FOR_SIGMOID)
+    multiplies<Array<T, N>> mul;
+    fast_tanh_op<Array<T, N>> tanh;
+    return mul(add(tanh(mul(z, cutlass::constants::half<T>())), cutlass::constants::one<T>()),
+               cutlass::constants::half<T>());
+#else
+    divides<Array<T, N>> div;
+    negate<Array<T, N>> neg;
+    fast_exp_op<Array<T, N>> fast_exp;
+    return div(cutlass::constants::one<T>(),
+               add(cutlass::constants::one<T>(),
+                   fast_exp(neg(z))));
+#endif
+  }
+};
+
 // SiLu (swish) operator introduced by Elfwing et al. in the following paper
 // "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017)
 // https://arxiv.org/pdf/1702.03118.pdf
diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h
index b1650b94..8ab296c2 100644
--- a/include/cutlass/fast_math.h
+++ b/include/cutlass/fast_math.h
@@ -705,12 +705,21 @@ float fast_exp(float x) {
 CUTLASS_HOST_DEVICE
 double fast_exp(double x) {
   #if defined(__CUDA_ARCH__)
   return ::exp(x);
   #else
   return std::exp(x);
   #endif
 }
 
+CUTLASS_HOST_DEVICE
+float fast_exp(half_t x) {
+  #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
+  return ::hexp(x.to_half());
+  #else
+  return fast_exp(float(x));
+  #endif
+}
+
 CUTLASS_HOST_DEVICE
 float fast_log(float x) {
   #if defined(__CUDA_ARCH__)
@@ -767,6 +776,61 @@ half_t fast_tanh(half_t x) {
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+template <typename T>
+struct fast_exp_op {
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &rhs) const {
+    return fast_exp(rhs);
+  }
+};
+
+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
+template <int N>
+struct fast_exp_op<Array<half_t, N>> {
+  CUTLASS_DEVICE
+  Array<half_t, N> operator()(Array<half_t, N> const &rhs) const {
+
+    Array<half_t, N> result;
+
+    // use x2 specialization
+    __half2 const *in = reinterpret_cast<__half2 const *>(&rhs);
+    __half2 *out = reinterpret_cast<__half2 *>(&result);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 2; ++i) {
+      out[i] = ::h2exp(in[i]);
+    }
+
+    // residual
+    if (N % 2) {
+      half_t last = rhs[N - 1];
+      result[N - 1] = half_t(::hexp(last.to_half()));
+    }
+
+    return result;
+  }
+};
+#endif // #if defined(__CUDA_ARCH__)
+
+template <typename T, int N>
+struct fast_exp_op<Array<T, N>> {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs) const {
+
+    fast_exp_op<T> fast_op;
+    Array<T, N> y;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      y[i] = fast_op(rhs[i]);
+    }
+
+    return y;
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 template <typename T>
 struct fast_tanh_op {
   CUTLASS_HOST_DEVICE
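Note (not part of the patch): the sketch below shows one way the new Array<half_t, N> sigmoid specialization could be exercised from a kernel. The kernel name, the fragment width of 8, and the pointer arithmetic are illustrative assumptions rather than CUTLASS API; it assumes the headers modified above are on the include path and that compiling for SM75 or newer selects the ::h2exp fast path, with the scalar fast_exp fallback used elsewhere.

#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/epilogue/thread/activation.h"

// Hypothetical kernel: each thread applies the vectorized half_t sigmoid to one
// 8-element fragment. Buffers are assumed to hold blockDim.x * 8 elements.
__global__ void sigmoid_fragment_kernel(cutlass::half_t const *in, cutlass::half_t *out) {
  using Fragment = cutlass::Array<cutlass::half_t, 8>;

  // Dispatches to Sigmoid<Array<half_t, N>>, which goes through fast_exp_op and
  // the paired ::h2exp evaluation on SM75+.
  cutlass::epilogue::thread::Sigmoid<Fragment> sigmoid;

  Fragment const *src = reinterpret_cast<Fragment const *>(in) + threadIdx.x;
  Fragment *dst = reinterpret_cast<Fragment *>(out) + threadIdx.x;

  *dst = sigmoid(*src);
}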