Support half precision sigmoid activation (#378)
* Support half precision sigmoid activation
* Introduce a vectorized variant using fast_tanh
* Move the math to fast_math.h
* Fix compile errors
* Replace .raw() with .to_half()

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
288af365db
commit
dceabd4c5a
@ -98,15 +98,7 @@ template <typename T>
|
||||
struct Sigmoid {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &scalar) const {
|
||||
return T(1) / (T(1) + exp(-scalar));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Sigmoid<float> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
float operator()(float const &scalar) const {
|
||||
return 1.0f / (1.0f + expf(-scalar));
|
||||
return T(1) / (T(1) + fast_exp(-scalar));
|
||||
}
|
||||
};
|
||||
|
||||
@ -126,6 +118,30 @@ struct Sigmoid<Array<T, N> > {
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization of Sigmoid for arrays of half_t, so the whole
/// array can be evaluated with vectorized half-precision functors.
template <int N>
struct Sigmoid<Array<half_t, N>> {
  using T = half_t;

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const& z) const {
    plus<Array<T, N>> add;

#if defined(CUTLASS_USE_TANH_FOR_SIGMOID)
    // Uses the identity sigmoid(z) = 0.5 * (1 + tanh(z / 2)) so the
    // fast_tanh_op vectorized tanh can be reused for sigmoid.
    multiplies<Array<T, N>> mul;
    fast_tanh_op<Array<T, N>> tanh;
    return mul(add(tanh(mul(z, cutlass::constants::half<T>())), cutlass::constants::one<T>()),
               cutlass::constants::half<T>());
#else
    // Direct definition: 1 / (1 + e^-z), with the exponential computed by
    // the vectorized fast_exp_op functor.
    divides<Array<T, N>> div;
    negate<Array<T, N>> neg;
    fast_exp_op<Array<T, N>> fast_exp;
    return div(cutlass::constants::one<T>(),
               add(cutlass::constants::one<T>(),
                   fast_exp(neg(z))));
#endif
  }
};
|
||||
|
||||
// SiLu (swish) operator introduced by Elfwing et al. in the following paper
|
||||
// "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017)
|
||||
// https://arxiv.org/pdf/1702.03118.pdf
|
||||
|
@ -705,12 +705,21 @@ float fast_exp(float x) {
|
||||
/// Computes e^x in double precision on both host and device.
CUTLASS_HOST_DEVICE
double fast_exp(double x) {
#if defined(__CUDA_ARCH__)
  // Use the double-precision device exp. An unreachable second return of
  // ::expf(x) followed here; it was dead code and would have silently lost
  // precision for a double argument, so it has been removed.
  return ::exp(x);
#else
  return std::exp(x);
#endif
}
|
||||
|
||||
/// Computes e^x for a half_t argument, returning float.
CUTLASS_HOST_DEVICE
float fast_exp(half_t x) {
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
  // Native half-precision exponential; requires CUDA 10+ and sm_75 or newer.
  return ::hexp(x.to_half());
#else
  // Fallback: promote to float and use the single-precision overload.
  return fast_exp(float(x));
#endif
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
float fast_log(float x) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
@ -767,6 +776,61 @@ half_t fast_tanh(half_t x) {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
struct fast_exp_op {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &rhs) const {
|
||||
return fast_exp(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
|
||||
template <int N>
|
||||
struct fast_exp_op<Array<half_t, N>> {
|
||||
CUTLASS_DEVICE
|
||||
Array<half_t, N> operator()(Array<half_t, N> const &rhs) const {
|
||||
|
||||
Array<half_t, N> result;
|
||||
|
||||
// use x2 specialization
|
||||
__half2 const *in = reinterpret_cast<__half2 const *>(&rhs);
|
||||
__half2 *out = reinterpret_cast<__half2 *>(&result);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 2; ++i) {
|
||||
out[i] = ::h2exp(in[i]);
|
||||
}
|
||||
|
||||
// residual
|
||||
if (N % 2) {
|
||||
half_t last = rhs[N - 1];
|
||||
result[N - 1] = half_t(::hexp(last.to_half()));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
#endif // #if defined(__CUDA_ARCH__)
|
||||
|
||||
template <typename T, int N>
|
||||
struct fast_exp_op<Array<T, N>> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &rhs) const {
|
||||
|
||||
fast_exp_op<T> fast_op;
|
||||
Array<T, N> y;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
y[i] = fast_op(rhs[i]);
|
||||
}
|
||||
|
||||
return y;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
struct fast_tanh_op {
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
Loading…
Reference in New Issue
Block a user