set kIsHeavy member variables (#1012)
* set kIsHeavy member variables
* correct kIsHeavy value for Tanh
* set kIsHeavy=false for HardSwish

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
61a38f83dc
commit
5f13dcad78
@ -37,6 +37,7 @@
|
|||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include "cutlass/numeric_types.h"
|
#include "cutlass/numeric_types.h"
|
||||||
|
#include "cutlass/numeric_conversion.h"
|
||||||
#include "cutlass/constants.h"
|
#include "cutlass/constants.h"
|
||||||
#include "cutlass/complex.h"
|
#include "cutlass/complex.h"
|
||||||
#include "cutlass/array.h"
|
#include "cutlass/array.h"
|
||||||
@ -54,7 +55,7 @@ namespace thread {
|
|||||||
// Identity operator
|
// Identity operator
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct Identity {
|
struct Identity {
|
||||||
static const bool kIsHeavy=false;
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T value) const {
|
T operator()(T value) const {
|
||||||
@ -128,7 +129,8 @@ struct Scale<Activation<T>> {
|
|||||||
/// Always put threshold in the right hand side of max to propagate NaN.
|
/// Always put threshold in the right hand side of max to propagate NaN.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct ReLu {
|
struct ReLu {
|
||||||
static const bool kIsHeavy=false;
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const & threshold, T value) const {
|
T operator()(T const & threshold, T value) const {
|
||||||
maximum<T> mx;
|
maximum<T> mx;
|
||||||
@ -149,7 +151,8 @@ using ReLU = ReLu<T>;
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct ReLu<Array<T, N>> {
|
struct ReLu<Array<T, N>> {
|
||||||
static const bool kIsHeavy=false;
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
|
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
|
||||||
maximum<Array<T, N>> mx;
|
maximum<Array<T, N>> mx;
|
||||||
@ -207,6 +210,9 @@ struct Clamp<Array<T,N>> {
|
|||||||
// Leaky Relu operator
|
// Leaky Relu operator
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LeakyReLU {
|
struct LeakyReLU {
|
||||||
|
|
||||||
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
T leaky_alpha = T(0);
|
T leaky_alpha = T(0);
|
||||||
};
|
};
|
||||||
@ -225,6 +231,9 @@ struct LeakyReLU {
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct LeakyReLU<Array<T, N> > {
|
struct LeakyReLU<Array<T, N> > {
|
||||||
|
|
||||||
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
using Arguments = typename LeakyReLU<T>::Arguments;
|
using Arguments = typename LeakyReLU<T>::Arguments;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
@ -249,6 +258,8 @@ struct LeakyReLU<Array<T, N> > {
|
|||||||
// Tanh operator
|
// Tanh operator
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct Tanh {
|
struct Tanh {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &value) const {
|
T operator()(T const &value) const {
|
||||||
return fast_tanh(value);
|
return fast_tanh(value);
|
||||||
@ -257,6 +268,8 @@ struct Tanh {
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct Tanh<Array<T, N> > {
|
struct Tanh<Array<T, N> > {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const &value) const {
|
Array<T, N> operator()(Array<T, N> const &value) const {
|
||||||
Array<T, N> y;
|
Array<T, N> y;
|
||||||
@ -274,6 +287,7 @@ struct Tanh<Array<T, N> > {
|
|||||||
template <int N>
|
template <int N>
|
||||||
struct Tanh<Array<half_t, N>> {
|
struct Tanh<Array<half_t, N>> {
|
||||||
using T = half_t;
|
using T = half_t;
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const& z) const {
|
Array<T, N> operator()(Array<T, N> const& z) const {
|
||||||
@ -285,6 +299,8 @@ struct Tanh<Array<half_t, N>> {
|
|||||||
// Sigmoid operator
|
// Sigmoid operator
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct Sigmoid {
|
struct Sigmoid {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &value) const {
|
T operator()(T const &value) const {
|
||||||
return T(1) / (T(1) + fast_exp(-value));
|
return T(1) / (T(1) + fast_exp(-value));
|
||||||
@ -293,6 +309,8 @@ struct Sigmoid {
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct Sigmoid<Array<T, N> > {
|
struct Sigmoid<Array<T, N> > {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const &value) const {
|
Array<T, N> operator()(Array<T, N> const &value) const {
|
||||||
Array<T, N> y;
|
Array<T, N> y;
|
||||||
@ -310,6 +328,7 @@ struct Sigmoid<Array<T, N> > {
|
|||||||
template <int N>
|
template <int N>
|
||||||
struct Sigmoid<Array<half_t, N>> {
|
struct Sigmoid<Array<half_t, N>> {
|
||||||
using T = half_t;
|
using T = half_t;
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const& z) const {
|
Array<T, N> operator()(Array<T, N> const& z) const {
|
||||||
@ -338,6 +357,8 @@ struct Sigmoid<Array<half_t, N>> {
|
|||||||
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
|
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct SiLu {
|
struct SiLu {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &value) const {
|
T operator()(T const &value) const {
|
||||||
Sigmoid<T> sigmoid;
|
Sigmoid<T> sigmoid;
|
||||||
@ -347,6 +368,8 @@ struct SiLu {
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct SiLu<Array<T, N>> {
|
struct SiLu<Array<T, N>> {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const &value) const {
|
Array<T, N> operator()(Array<T, N> const &value) const {
|
||||||
Sigmoid<Array<T, N>> sigmoid_op;
|
Sigmoid<Array<T, N>> sigmoid_op;
|
||||||
@ -362,6 +385,8 @@ struct SiLu<Array<T, N>> {
|
|||||||
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
|
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct HardSwish {
|
struct HardSwish {
|
||||||
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &x) const {
|
T operator()(T const &x) const {
|
||||||
minimum<T> mn;
|
minimum<T> mn;
|
||||||
@ -374,6 +399,7 @@ struct HardSwish {
|
|||||||
template <>
|
template <>
|
||||||
struct HardSwish<float> {
|
struct HardSwish<float> {
|
||||||
using T = float;
|
using T = float;
|
||||||
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &x) const {
|
T operator()(T const &x) const {
|
||||||
@ -386,6 +412,8 @@ struct HardSwish<float> {
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct HardSwish<Array<T, N> > {
|
struct HardSwish<Array<T, N> > {
|
||||||
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const &value) const {
|
Array<T, N> operator()(Array<T, N> const &value) const {
|
||||||
Array<T, N> y;
|
Array<T, N> y;
|
||||||
@ -403,6 +431,7 @@ struct HardSwish<Array<T, N> > {
|
|||||||
template <int N>
|
template <int N>
|
||||||
struct HardSwish<Array<half_t, N> > {
|
struct HardSwish<Array<half_t, N> > {
|
||||||
using T = half_t;
|
using T = half_t;
|
||||||
|
static const bool kIsHeavy = false;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const &value) const {
|
Array<T, N> operator()(Array<T, N> const &value) const {
|
||||||
@ -427,6 +456,8 @@ struct HardSwish<Array<half_t, N> > {
|
|||||||
// GELU operator
|
// GELU operator
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct GELU {
|
struct GELU {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &value) const {
|
T operator()(T const &value) const {
|
||||||
return T(cutlass::constants::half<T>() * value *
|
return T(cutlass::constants::half<T>() * value *
|
||||||
@ -436,6 +467,8 @@ struct GELU {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct GELU<float> {
|
struct GELU<float> {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
float operator()(float const &value) const {
|
float operator()(float const &value) const {
|
||||||
return cutlass::constants::half<float>() * value *
|
return cutlass::constants::half<float>() * value *
|
||||||
@ -445,6 +478,8 @@ struct GELU<float> {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct GELU<double> {
|
struct GELU<double> {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
double operator()(double const &value) const {
|
double operator()(double const &value) const {
|
||||||
return cutlass::constants::half<double>() * value *
|
return cutlass::constants::half<double>() * value *
|
||||||
@ -454,6 +489,8 @@ struct GELU<double> {
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct GELU<Array<T, N> > {
|
struct GELU<Array<T, N> > {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const &value) const {
|
Array<T, N> operator()(Array<T, N> const &value) const {
|
||||||
Array<T, N> y;
|
Array<T, N> y;
|
||||||
@ -474,7 +511,8 @@ using ScaledGELU = Scale<GELU<T>>;
|
|||||||
// GELU operator implemented using the Taylor series approximation
|
// GELU operator implemented using the Taylor series approximation
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct GELU_taylor {
|
struct GELU_taylor {
|
||||||
static const bool kIsHeavy=true;
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &z) const {
|
T operator()(T const &z) const {
|
||||||
|
|
||||||
@ -488,7 +526,8 @@ struct GELU_taylor {
|
|||||||
|
|
||||||
template <int N>
|
template <int N>
|
||||||
struct GELU_taylor<Array<half_t, N> > {
|
struct GELU_taylor<Array<half_t, N> > {
|
||||||
static const bool kIsHeavy=true;
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<half_t, N> operator()(Array<half_t, N> const &z) const {
|
Array<half_t, N> operator()(Array<half_t, N> const &z) const {
|
||||||
|
|
||||||
@ -514,7 +553,8 @@ struct GELU_taylor<Array<half_t, N> > {
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct GELU_taylor<Array<T, N> > {
|
struct GELU_taylor<Array<T, N> > {
|
||||||
static const bool kIsHeavy=true;
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const &value) const {
|
Array<T, N> operator()(Array<T, N> const &value) const {
|
||||||
Array<T, N> y;
|
Array<T, N> y;
|
||||||
@ -536,6 +576,8 @@ using ScaledGELU_taylor = Scale<GELU_taylor<T>>;
|
|||||||
/// z is computed from the forward pass.
|
/// z is computed from the forward pass.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct dGELU {
|
struct dGELU {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
T operator()(T const &d_t, T const &z) const {
|
T operator()(T const &d_t, T const &z) const {
|
||||||
|
|
||||||
@ -554,6 +596,8 @@ struct dGELU {
|
|||||||
|
|
||||||
template <typename T, int N>
|
template <typename T, int N>
|
||||||
struct dGELU<Array<T, N> > {
|
struct dGELU<Array<T, N> > {
|
||||||
|
static const bool kIsHeavy = true;
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Array<T, N> operator()(Array<T, N> const &d_t, Array<T, N> const &z) const {
|
Array<T, N> operator()(Array<T, N> const &d_t, Array<T, N> const &z) const {
|
||||||
Array<T, N> y;
|
Array<T, N> y;
|
||||||
@ -568,6 +612,45 @@ struct dGELU<Array<T, N> > {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct dReLU {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
T operator()(T const& d_t, bool d_relu) const {
|
||||||
|
return d_relu ? d_t : T(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
T operator()(T const& d_t, uint1b_t d_relu) const {
|
||||||
|
return operator()(d_t, static_cast<bool>(d_relu));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int N>
|
||||||
|
struct dReLU<Array<T, N>> {
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<T, N> operator()(Array<T, N> const& d_t, bool const (&d_relu)[N]) const {
|
||||||
|
Array<T, N> y;
|
||||||
|
dReLU<T> relu_op;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
y[i] = relu_op(d_t[i], d_relu[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Array<T, N> operator()(Array<T, N> const& d_t, Array<uint1b_t, N> const& d_relu) const {
|
||||||
|
UnpackPredicates<N> unpack_op;
|
||||||
|
|
||||||
|
bool preds[N];
|
||||||
|
unpack_op(preds, d_relu);
|
||||||
|
|
||||||
|
return operator()(d_t, preds);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
} // namespace thread
|
} // namespace thread
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user