set kIsHeavy member variables (#1012)
* set kIsHeavy member variables
* correct kIsHeavy value for Tanh
* set kIsHeavy=false for HardSwish

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent 61a38f83dc
commit 5f13dcad78
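The kIsHeavy values set in this commit follow a consistent pattern: activations that evaluate transcendentals (Tanh, Sigmoid, SiLu, GELU, GELU_taylor, dGELU) are marked heavy, while piecewise/linear ones (Identity, ReLU, LeakyReLU, HardSwish) are marked light. A minimal sketch of how calling code can consume such a boolean trait, using hypothetical functors rather than CUTLASS code:

// Illustrative sketch only, not CUTLASS code: MyReLU, MyGELU, and describe()
// are hypothetical stand-ins showing how a boolean trait named like kIsHeavy
// can be read by calling code (e.g. to pick a cheaper schedule for light
// activations).
#include <cmath>
#include <cstdio>

template <typename T>
struct MyReLU {
  static const bool kIsHeavy = false;  // a compare/select, no transcendentals
  T operator()(T x) const { return x < T(0) ? T(0) : x; }
};

template <typename T>
struct MyGELU {
  static const bool kIsHeavy = true;   // evaluates erf, a transcendental
  T operator()(T x) const {
    return T(0.5) * x * (T(1) + std::erf(x * T(0.7071067811865476)));
  }
};

template <typename Activation>
void describe(char const* name) {
  // kIsHeavy is a compile-time constant, so it can also drive if constexpr
  // dispatch rather than a runtime branch.
  std::printf("%s: kIsHeavy = %s\n", name, Activation::kIsHeavy ? "true" : "false");
}

int main() {
  describe<MyReLU<float>>("ReLU");
  describe<MyGELU<float>>("GELU");

  MyGELU<float> gelu;
  std::printf("gelu(1.5f) = %f\n", gelu(1.5f));
  return 0;
}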
@@ -37,6 +37,7 @@
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
+#include "cutlass/numeric_conversion.h"
 #include "cutlass/constants.h"
 #include "cutlass/complex.h"
 #include "cutlass/array.h"
@@ -54,7 +55,7 @@ namespace thread {
 // Identity operator
 template <typename T>
 struct Identity {
-  static const bool kIsHeavy=false;
+  static const bool kIsHeavy = false;
 
   CUTLASS_HOST_DEVICE
   T operator()(T value) const {
@@ -128,7 +129,8 @@ struct Scale<Activation<T>> {
 /// Always put threshold in the right hand side of max to propagate NaN.
 template <typename T>
 struct ReLu {
-  static const bool kIsHeavy=false;
+  static const bool kIsHeavy = false;
+
   CUTLASS_HOST_DEVICE
   T operator()(T const & threshold, T value) const {
     maximum<T> mx;
@@ -149,7 +151,8 @@ using ReLU = ReLu<T>;
 
 template <typename T, int N>
 struct ReLu<Array<T, N>> {
-  static const bool kIsHeavy=false;
+  static const bool kIsHeavy = false;
+
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
     maximum<Array<T, N>> mx;
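The comment above ReLu, "Always put threshold in the right hand side of max to propagate NaN", refers to how a comparison-based max treats NaN: every comparison with NaN is false, so a ternary max of the form (a < b ? b : a) returns its left operand. Keeping the value on the left and the threshold on the right therefore preserves a NaN value. A standalone illustration (plain C++, assuming that ternary form; not CUTLASS's maximum functor):

// Illustrative only. With a comparison-based max of the form (a < b ? b : a),
// a comparison involving NaN is false, so the left operand is returned.
#include <cmath>
#include <cstdio>

template <typename T>
T ternary_max(T a, T b) { return (a < b) ? b : a; }

int main() {
  float threshold = 0.0f;
  float value = std::nanf("");

  // Value on the left, threshold on the right: NaN < 0 is false, NaN survives.
  std::printf("%f\n", ternary_max(value, threshold));  // nan

  // Swapped order: 0 < NaN is false, the NaN is silently replaced by 0.
  std::printf("%f\n", ternary_max(threshold, value));  // 0.000000
  return 0;
}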
@@ -207,6 +210,9 @@ struct Clamp<Array<T,N>> {
 // Leaky Relu operator
 template <typename T>
 struct LeakyReLU {
+
+  static const bool kIsHeavy = false;
+
   struct Arguments {
     T leaky_alpha = T(0);
   };
@@ -225,6 +231,9 @@ struct LeakyReLU {
 
 template <typename T, int N>
 struct LeakyReLU<Array<T, N> > {
+
+  static const bool kIsHeavy = false;
+
   using Arguments = typename LeakyReLU<T>::Arguments;
 
   CUTLASS_HOST_DEVICE
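LeakyReLU carries its slope through the Arguments struct shown above (leaky_alpha); the functor body itself is not part of this hunk. A host-side reference assuming the usual definition x >= 0 ? x : alpha * x:

// Host-side reference for a leaky ReLU parameterized the way the Arguments
// struct above suggests (leaky_alpha). Illustrative only.
#include <cstdio>

struct LeakyReluRef {
  float leaky_alpha = 0.0f;
  float operator()(float x) const { return x >= 0.0f ? x : leaky_alpha * x; }
};

int main() {
  LeakyReluRef op{0.01f};
  for (float x : {-2.0f, -0.5f, 0.0f, 1.5f}) {
    std::printf("leaky_relu(% .1f) = % .4f\n", x, op(x));
  }
  return 0;
}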
@@ -249,6 +258,8 @@ struct LeakyReLU<Array<T, N> > {
 // Tanh operator
 template <typename T>
 struct Tanh {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   T operator()(T const &value) const {
     return fast_tanh(value);
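Tanh is one of the operators marked kIsHeavy = true; its scalar body calls fast_tanh, a transcendental. For reference only, tanh written out from its exponential definition in plain C++ (not CUTLASS code):

// Reference tanh from its exponential definition; std::tanh is the host
// library analogue of what a device-side fast_tanh approximates.
#include <cmath>
#include <cstdio>

double tanh_ref(double x) {
  double e2x = std::exp(2.0 * x);
  return (e2x - 1.0) / (e2x + 1.0);
}

int main() {
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
    std::printf("x=% .2f  tanh_ref=% .6f  std::tanh=% .6f\n", x, tanh_ref(x), std::tanh(x));
  }
  return 0;
}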
@@ -257,6 +268,8 @@ struct Tanh {
 
 template <typename T, int N>
 struct Tanh<Array<T, N> > {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const &value) const {
     Array<T, N> y;
@@ -274,6 +287,7 @@ struct Tanh<Array<T, N> > {
 template <int N>
 struct Tanh<Array<half_t, N>> {
   using T = half_t;
+  static const bool kIsHeavy = true;
 
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const& z) const {
@@ -285,6 +299,8 @@ struct Tanh<Array<half_t, N>> {
 // Sigmoid operator
 template <typename T>
 struct Sigmoid {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   T operator()(T const &value) const {
     return T(1) / (T(1) + fast_exp(-value));
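The scalar Sigmoid body visible above computes the logistic function 1 / (1 + e^(-x)). A host-side reference using std::exp (illustrative only):

// Host-side reference for the sigmoid body shown in the diff:
// T(1) / (T(1) + fast_exp(-value)).
#include <cmath>
#include <cstdio>

float sigmoid_ref(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main() {
  for (float x : {-4.0f, -1.0f, 0.0f, 1.0f, 4.0f}) {
    std::printf("sigmoid(% .1f) = %.6f\n", x, sigmoid_ref(x));
  }
  return 0;
}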
@@ -293,6 +309,8 @@ struct Sigmoid {
 
 template <typename T, int N>
 struct Sigmoid<Array<T, N> > {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const &value) const {
     Array<T, N> y;
@@ -310,6 +328,7 @@ struct Sigmoid<Array<T, N> > {
 template <int N>
 struct Sigmoid<Array<half_t, N>> {
   using T = half_t;
+  static const bool kIsHeavy = true;
 
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const& z) const {
@@ -338,6 +357,8 @@ struct Sigmoid<Array<half_t, N>> {
 // Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
 template <typename T>
 struct SiLu {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   T operator()(T const &value) const {
     Sigmoid<T> sigmoid;
@@ -347,6 +368,8 @@ struct SiLu {
 
 template <typename T, int N>
 struct SiLu<Array<T, N>> {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const &value) const {
     Sigmoid<Array<T, N>> sigmoid_op;
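Per the PyTorch reference linked in the header comment, SiLU is x * sigmoid(x), which is why the scalar functor instantiates Sigmoid<T> and why it is also marked heavy. A host-side reference (illustrative only):

// SiLU (swish): x * sigmoid(x), following the PyTorch reference cited in the
// header comment. Illustrative host code, not the CUTLASS functor.
#include <cmath>
#include <cstdio>

float silu_ref(float x) { return x / (1.0f + std::exp(-x)); }

int main() {
  for (float x : {-3.0f, -1.0f, 0.0f, 1.0f, 3.0f}) {
    std::printf("silu(% .1f) = % .6f\n", x, silu_ref(x));
  }
  return 0;
}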
@@ -362,6 +385,8 @@ struct SiLu<Array<T, N>> {
 // Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
 template <typename T>
 struct HardSwish {
+  static const bool kIsHeavy = false;
+
   CUTLASS_HOST_DEVICE
   T operator()(T const &x) const {
     minimum<T> mn;
@@ -374,6 +399,7 @@ struct HardSwish {
 template <>
 struct HardSwish<float> {
   using T = float;
+  static const bool kIsHeavy = false;
 
   CUTLASS_HOST_DEVICE
   T operator()(T const &x) const {
@@ -386,6 +412,8 @@ struct HardSwish<float> {
 
 template <typename T, int N>
 struct HardSwish<Array<T, N> > {
+  static const bool kIsHeavy = false;
+
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const &value) const {
     Array<T, N> y;
@@ -403,6 +431,7 @@ struct HardSwish<Array<T, N> > {
 template <int N>
 struct HardSwish<Array<half_t, N> > {
   using T = half_t;
+  static const bool kIsHeavy = false;
 
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const &value) const {
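HardSwish, per the PyTorch reference cited above, is x * relu6(x + 3) / 6: only min, max, and multiplies, which is consistent with the kIsHeavy = false chosen here. A host-side reference (illustrative only):

// HardSwish per the PyTorch reference: x * relu6(x + 3) / 6. Only min, max,
// and a multiply, i.e. no transcendentals. Illustrative host code.
#include <algorithm>
#include <cstdio>

float hardswish_ref(float x) {
  float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
  return x * relu6 / 6.0f;
}

int main() {
  for (float x : {-4.0f, -1.0f, 0.0f, 1.0f, 4.0f}) {
    std::printf("hardswish(% .1f) = % .6f\n", x, hardswish_ref(x));
  }
  return 0;
}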
@@ -427,6 +456,8 @@ struct HardSwish<Array<half_t, N> > {
 // GELU operator
 template <typename T>
 struct GELU {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   T operator()(T const &value) const {
     return T(cutlass::constants::half<T>() * value *
@@ -436,6 +467,8 @@ struct GELU {
 
 template <>
 struct GELU<float> {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   float operator()(float const &value) const {
     return cutlass::constants::half<float>() * value *
@@ -445,6 +478,8 @@ struct GELU<float> {
 
 template <>
 struct GELU<double> {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   double operator()(double const &value) const {
     return cutlass::constants::half<double>() * value *
@@ -454,6 +489,8 @@ struct GELU<double> {
 
 template <typename T, int N>
 struct GELU<Array<T, N> > {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const &value) const {
     Array<T, N> y;
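The GELU bodies above begin with half<T>() * value * ..., which matches the exact erf-based GELU, 0.5 * x * (1 + erf(x / sqrt(2))). A host-side reference (illustrative only; the full functor body is truncated in this view):

// Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))). Host-side
// reference only; the device functor above uses cutlass::constants and is
// truncated in this view.
#include <cmath>
#include <cstdio>

double gelu_ref(double x) {
  return 0.5 * x * (1.0 + std::erf(x * 0.7071067811865476));  // 1/sqrt(2)
}

int main() {
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
    std::printf("gelu(% .2f) = % .6f\n", x, gelu_ref(x));
  }
  return 0;
}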
@@ -474,7 +511,8 @@ using ScaledGELU = Scale<GELU<T>>;
 // GELU operator implemented using the Taylor series approximation
 template <typename T>
 struct GELU_taylor {
-  static const bool kIsHeavy=true;
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   T operator()(T const &z) const {
 
@@ -488,7 +526,8 @@ struct GELU_taylor {
 
 template <int N>
 struct GELU_taylor<Array<half_t, N> > {
-  static const bool kIsHeavy=true;
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   Array<half_t, N> operator()(Array<half_t, N> const &z) const {
 
@@ -514,7 +553,8 @@ struct GELU_taylor<Array<half_t, N> > {
 
 template <typename T, int N>
 struct GELU_taylor<Array<T, N> > {
-  static const bool kIsHeavy=true;
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const &value) const {
     Array<T, N> y;
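The GELU_taylor bodies are not shown in these hunks. The sketch below is the widely used tanh-based GELU approximation, included purely as a reference point for an approximate GELU; it is not presented as CUTLASS's exact implementation:

// The common tanh-based GELU approximation:
//   0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
// Included as a reference for an approximate GELU only; GELU_taylor's actual
// body is not visible in this diff, so no claim is made that it matches.
#include <cmath>
#include <cstdio>

double gelu_tanh_approx(double x) {
  const double k0 = 0.7978845608028654;  // sqrt(2 / pi)
  const double k1 = 0.044715;
  return 0.5 * x * (1.0 + std::tanh(k0 * (x + k1 * x * x * x)));
}

int main() {
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
    std::printf("gelu_approx(% .2f) = % .6f\n", x, gelu_tanh_approx(x));
  }
  return 0;
}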
@@ -536,6 +576,8 @@ using ScaledGELU_taylor = Scale<GELU_taylor<T>>;
 /// z is computed from the forward pass.
 template <typename T>
 struct dGELU {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   T operator()(T const &d_t, T const &z) const {
 
@@ -554,6 +596,8 @@ struct dGELU {
 
 template <typename T, int N>
 struct dGELU<Array<T, N> > {
+  static const bool kIsHeavy = true;
+
   CUTLASS_HOST_DEVICE
   Array<T, N> operator()(Array<T, N> const &d_t, Array<T, N> const &z) const {
     Array<T, N> y;
@@ -568,6 +612,45 @@ struct dGELU<Array<T, N> > {
   }
 };
 
+template <typename T>
+struct dReLU {
+  CUTLASS_HOST_DEVICE
+  T operator()(T const& d_t, bool d_relu) const {
+    return d_relu ? d_t : T(0);
+  }
+
+  CUTLASS_HOST_DEVICE
+  T operator()(T const& d_t, uint1b_t d_relu) const {
+    return operator()(d_t, static_cast<bool>(d_relu));
+  }
+};
+
+template <typename T, int N>
+struct dReLU<Array<T, N>> {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const& d_t, bool const (&d_relu)[N]) const {
+    Array<T, N> y;
+    dReLU<T> relu_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      y[i] = relu_op(d_t[i], d_relu[i]);
+    }
+
+    return y;
+  }
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const& d_t, Array<uint1b_t, N> const& d_relu) const {
+    UnpackPredicates<N> unpack_op;
+
+    bool preds[N];
+    unpack_op(preds, d_relu);
+
+    return operator()(d_t, preds);
+  }
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace thread
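The dReLU functors added at the end of the diff implement the ReLU backward pass: keep a gradient element where its predicate is set, zero it otherwise; the Array<uint1b_t, N> overload does the same after unpacking one predicate bit per element. A plain-C++ sketch of that behavior (not CUTLASS code):

// Plain-C++ sketch of the dReLU semantics added above: keep the gradient
// where the forward ReLU kept the input (predicate true), zero it otherwise.
#include <array>
#include <cstdio>

template <typename T, int N>
std::array<T, N> drelu_ref(std::array<T, N> const& d_t, std::array<bool, N> const& d_relu) {
  std::array<T, N> y{};
  for (int i = 0; i < N; ++i) {
    y[i] = d_relu[i] ? d_t[i] : T(0);
  }
  return y;
}

int main() {
  std::array<float, 4> grad{0.5f, -1.25f, 2.0f, -0.75f};
  std::array<bool, 4> keep{true, false, true, false};
  auto y = drelu_ref<float, 4>(grad, keep);
  for (float v : y) {
    std::printf("% .2f ", v);  // 0.50 0.00 2.00 0.00
  }
  std::printf("\n");
  return 0;
}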