Support half precision sigmoid activation (#378)

* Support half precision sigmoid activation
* introduce a vectorized variant using fast_tanh
* move the math to fast_math.h
* fixed compile
* .raw() -> .to_half()

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent 288af365db
commit dceabd4c5a
include/cutlass/epilogue/thread/activation.h

@@ -98,15 +98,7 @@ template <typename T>
 struct Sigmoid {
   CUTLASS_HOST_DEVICE
   T operator()(T const &scalar) const {
-    return T(1) / (T(1) + exp(-scalar));
-  }
-};
-
-template <>
-struct Sigmoid<float> {
-  CUTLASS_HOST_DEVICE
-  float operator()(float const &scalar) const {
-    return 1.0f / (1.0f + expf(-scalar));
+    return T(1) / (T(1) + fast_exp(-scalar));
   }
 };
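The explicit Sigmoid<float> specialization can be dropped because overload resolution on fast_exp already selects the single-precision path. A minimal standalone sketch of that mechanism, using toy fast_exp overloads rather than the CUTLASS source:

// Toy sketch: with per-type fast_exp overloads, the generic Sigmoid<T>
// already resolves to the float path for T = float, so the explicit
// Sigmoid<float> specialization removed above is redundant.
#include <cmath>
#include <cstdio>

float  fast_exp(float x)  { return std::exp(x); }  // device code would call the ::expf intrinsic here
double fast_exp(double x) { return std::exp(x); }

template <typename T>
struct Sigmoid {
  T operator()(T const &scalar) const {
    return T(1) / (T(1) + fast_exp(-scalar));      // picks fast_exp(float) when T = float
  }
};

int main() {
  Sigmoid<float> sigmoid;
  std::printf("sigmoid(0.0f) = %f\n", sigmoid(0.0f));  // prints 0.5
}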
@@ -126,6 +118,30 @@ struct Sigmoid<Array<T, N> > {
   }
 };

+template <int N>
+struct Sigmoid<Array<half_t, N>> {
+  using T = half_t;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const& z) const {
+    plus<Array<T, N>> add;
+
+#if defined(CUTLASS_USE_TANH_FOR_SIGMOID)
+    multiplies<Array<T, N>> mul;
+    fast_tanh_op<Array<T, N>> tanh;
+    return mul(add(tanh(mul(z, cutlass::constants::half<T>())), cutlass::constants::one<T>()),
+               cutlass::constants::half<T>());
+#else
+    divides<Array<T, N>> div;
+    negate<Array<T, N>> neg;
+    fast_exp_op<Array<T, N>> fast_exp;
+    return div(cutlass::constants::one<T>(),
+               add(cutlass::constants::one<T>(),
+                   fast_exp(neg(z))));
+#endif
+  }
+};
+
 // SiLu (swish) operator introduced by Elfwing et al. in the following paper
 // "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017)
 // https://arxiv.org/pdf/1702.03118.pdf
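The CUTLASS_USE_TANH_FOR_SIGMOID branch relies on the identity sigmoid(z) = (tanh(z/2) + 1) / 2, which lets the vectorized fast_tanh path stand in for the exp-based form. A quick standalone check of the identity, in plain C++ rather than the CUTLASS array types:

// Verifies that the tanh branch computes the same function as the exp branch:
// 1 / (1 + e^-z)  ==  0.5 * (tanh(z / 2) + 1)  for all z.
#include <cmath>
#include <cstdio>

int main() {
  for (float z = -4.0f; z <= 4.0f; z += 1.0f) {
    float via_exp  = 1.0f / (1.0f + std::exp(-z));
    float via_tanh = 0.5f * (std::tanh(0.5f * z) + 1.0f);
    std::printf("z=%+.1f  exp form=%.6f  tanh form=%.6f\n", z, via_exp, via_tanh);
  }
}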
include/cutlass/fast_math.h

@@ -705,12 +705,21 @@ float fast_exp(float x) {
 CUTLASS_HOST_DEVICE
 double fast_exp(double x) {
 #if defined(__CUDA_ARCH__)
-  return ::exp(x);
+  return ::expf(x);
 #else
   return std::exp(x);
 #endif
 }

+CUTLASS_HOST_DEVICE
+float fast_exp(half_t x) {
+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
+  return ::hexp(x.to_half());
+#else
+  return fast_exp(float(x));
+#endif
+}
+
 CUTLASS_HOST_DEVICE
 float fast_log(float x) {
 #if defined(__CUDA_ARCH__)
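On SM75+ the new scalar overload forwards to the ::hexp intrinsic from <cuda_fp16.h>; on the host path it simply widens to float. A host-only usage sketch, assuming the CUTLASS headers are on the include path and the cutlass::fast_exp / cutlass::half_t names from this diff:

// Host-side sketch: fast_exp(half_t) falls back to fast_exp(float(x))
// when no half-precision hardware path is available.
#include <cstdio>
#include "cutlass/half.h"
#include "cutlass/fast_math.h"

int main() {
  cutlass::half_t z(1.0f);
  float e = cutlass::fast_exp(z);  // host: fast_exp(float(z)), i.e. e^1
  float sigmoid = 1.0f / (1.0f + cutlass::fast_exp(cutlass::half_t(-1.0f)));
  std::printf("fast_exp(1) ~ %f  sigmoid(1) ~ %f\n", e, sigmoid);
}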
@@ -767,6 +776,61 @@ half_t fast_tanh(half_t x) {

 /////////////////////////////////////////////////////////////////////////////////////////////////

+template <typename T>
+struct fast_exp_op {
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &rhs) const {
+    return fast_exp(rhs);
+  }
+};
+
+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
+template <int N>
+struct fast_exp_op<Array<half_t, N>> {
+  CUTLASS_DEVICE
+  Array<half_t, N> operator()(Array<half_t, N> const &rhs) const {
+
+    Array<half_t, N> result;
+
+    // use x2 specialization
+    __half2 const *in = reinterpret_cast<__half2 const *>(&rhs);
+    __half2 *out = reinterpret_cast<__half2 *>(&result);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 2; ++i) {
+      out[i] = ::h2exp(in[i]);
+    }
+
+    // residual
+    if (N % 2) {
+      half_t last = rhs[N - 1];
+      result[N - 1] = half_t(::hexp(last.to_half()));
+    }
+
+    return result;
+  }
+};
+#endif // #if defined(__CUDA_ARCH__)
+
+template <typename T, int N>
+struct fast_exp_op<Array<T, N>> {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs) const {
+
+    fast_exp_op<T> fast_op;
+    Array<T, N> y;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      y[i] = fast_op(rhs[i]);
+    }
+
+    return y;
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 template <typename T>
 struct fast_tanh_op {
   CUTLASS_HOST_DEVICE
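The Array<half_t, N> specialization above leans on the paired __half2 intrinsics: reinterpret the array as N/2 __half2 values, evaluate two exponentials per ::h2exp call, and mop up an odd trailing element with scalar ::hexp. A self-contained CUDA sketch of the same pattern, with a hypothetical kernel that is not part of the commit; it must be compiled for a half-arithmetic architecture, e.g. nvcc -arch=sm_75:

// Demonstrates the x2 trick used by fast_exp_op<Array<half_t, N>>:
// pairs go through h2exp, the odd tail element through hexp.
#include <cuda_fp16.h>
#include <cstdio>

constexpr int kN = 5;  // odd on purpose, to exercise the residual path

__global__ void exp_kernel(__half const *in, __half *out) {
  // One thread handles the whole (tiny) array; real code would tile this.
  __half2 const *in2  = reinterpret_cast<__half2 const *>(in);
  __half2       *out2 = reinterpret_cast<__half2 *>(out);

  for (int i = 0; i < kN / 2; ++i) {
    out2[i] = h2exp(in2[i]);          // two half-precision exponentials per call
  }
  if (kN % 2) {
    out[kN - 1] = hexp(in[kN - 1]);   // residual element, scalar path
  }
}

int main() {
  __half h_in[kN], h_out[kN];
  for (int i = 0; i < kN; ++i) h_in[i] = __float2half(float(i) - 2.0f);  // -2 .. +2

  __half *d_in, *d_out;
  cudaMalloc(&d_in,  kN * sizeof(__half));
  cudaMalloc(&d_out, kN * sizeof(__half));
  cudaMemcpy(d_in, h_in, kN * sizeof(__half), cudaMemcpyHostToDevice);

  exp_kernel<<<1, 1>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, kN * sizeof(__half), cudaMemcpyDeviceToHost);

  for (int i = 0; i < kN; ++i)
    std::printf("exp(%+.1f) ~ %.4f\n", __half2float(h_in[i]), __half2float(h_out[i]));

  cudaFree(d_in);
  cudaFree(d_out);
}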