set kIsHeavy member variables (#1012)

* set kIsHeavy member variables

* correct kIsHeavy value for Tanh

* set kIsHeavy=false for HardSwish

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Fabian Schuetze 2023-10-04 18:38:36 +02:00 committed by GitHub
parent 61a38f83dc
commit 5f13dcad78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -37,6 +37,7 @@
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/constants.h" #include "cutlass/constants.h"
#include "cutlass/complex.h" #include "cutlass/complex.h"
#include "cutlass/array.h" #include "cutlass/array.h"
@ -129,6 +130,7 @@ struct Scale<Activation<T>> {
template <typename T> template <typename T>
struct ReLu { struct ReLu {
static const bool kIsHeavy = false; static const bool kIsHeavy = false;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const & threshold, T value) const { T operator()(T const & threshold, T value) const {
maximum<T> mx; maximum<T> mx;
@ -150,6 +152,7 @@ using ReLU = ReLu<T>;
template <typename T, int N> template <typename T, int N>
struct ReLu<Array<T, N>> { struct ReLu<Array<T, N>> {
static const bool kIsHeavy = false; static const bool kIsHeavy = false;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const { Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
maximum<Array<T, N>> mx; maximum<Array<T, N>> mx;
@ -207,6 +210,9 @@ struct Clamp<Array<T,N>> {
// Leaky Relu operator // Leaky Relu operator
template <typename T> template <typename T>
struct LeakyReLU { struct LeakyReLU {
static const bool kIsHeavy = false;
struct Arguments { struct Arguments {
T leaky_alpha = T(0); T leaky_alpha = T(0);
}; };
@ -225,6 +231,9 @@ struct LeakyReLU {
template <typename T, int N> template <typename T, int N>
struct LeakyReLU<Array<T, N> > { struct LeakyReLU<Array<T, N> > {
static const bool kIsHeavy = false;
using Arguments = typename LeakyReLU<T>::Arguments; using Arguments = typename LeakyReLU<T>::Arguments;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
@ -249,6 +258,8 @@ struct LeakyReLU<Array<T, N> > {
// Tanh operator // Tanh operator
template <typename T> template <typename T>
struct Tanh { struct Tanh {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const &value) const { T operator()(T const &value) const {
return fast_tanh(value); return fast_tanh(value);
@ -257,6 +268,8 @@ struct Tanh {
template <typename T, int N> template <typename T, int N>
struct Tanh<Array<T, N> > { struct Tanh<Array<T, N> > {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const { Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y; Array<T, N> y;
@ -274,6 +287,7 @@ struct Tanh<Array<T, N> > {
template <int N> template <int N>
struct Tanh<Array<half_t, N>> { struct Tanh<Array<half_t, N>> {
using T = half_t; using T = half_t;
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const& z) const { Array<T, N> operator()(Array<T, N> const& z) const {
@ -285,6 +299,8 @@ struct Tanh<Array<half_t, N>> {
// Sigmoid operator // Sigmoid operator
template <typename T> template <typename T>
struct Sigmoid { struct Sigmoid {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const &value) const { T operator()(T const &value) const {
return T(1) / (T(1) + fast_exp(-value)); return T(1) / (T(1) + fast_exp(-value));
@ -293,6 +309,8 @@ struct Sigmoid {
template <typename T, int N> template <typename T, int N>
struct Sigmoid<Array<T, N> > { struct Sigmoid<Array<T, N> > {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const { Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y; Array<T, N> y;
@ -310,6 +328,7 @@ struct Sigmoid<Array<T, N> > {
template <int N> template <int N>
struct Sigmoid<Array<half_t, N>> { struct Sigmoid<Array<half_t, N>> {
using T = half_t; using T = half_t;
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const& z) const { Array<T, N> operator()(Array<T, N> const& z) const {
@ -338,6 +357,8 @@ struct Sigmoid<Array<half_t, N>> {
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html // Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
template <typename T> template <typename T>
struct SiLu { struct SiLu {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const &value) const { T operator()(T const &value) const {
Sigmoid<T> sigmoid; Sigmoid<T> sigmoid;
@ -347,6 +368,8 @@ struct SiLu {
template <typename T, int N> template <typename T, int N>
struct SiLu<Array<T, N>> { struct SiLu<Array<T, N>> {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const { Array<T, N> operator()(Array<T, N> const &value) const {
Sigmoid<Array<T, N>> sigmoid_op; Sigmoid<Array<T, N>> sigmoid_op;
@ -362,6 +385,8 @@ struct SiLu<Array<T, N>> {
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html // Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
template <typename T> template <typename T>
struct HardSwish { struct HardSwish {
static const bool kIsHeavy = false;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const &x) const { T operator()(T const &x) const {
minimum<T> mn; minimum<T> mn;
@ -374,6 +399,7 @@ struct HardSwish {
template <> template <>
struct HardSwish<float> { struct HardSwish<float> {
using T = float; using T = float;
static const bool kIsHeavy = false;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const &x) const { T operator()(T const &x) const {
@ -386,6 +412,8 @@ struct HardSwish<float> {
template <typename T, int N> template <typename T, int N>
struct HardSwish<Array<T, N> > { struct HardSwish<Array<T, N> > {
static const bool kIsHeavy = false;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const { Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y; Array<T, N> y;
@ -403,6 +431,7 @@ struct HardSwish<Array<T, N> > {
template <int N> template <int N>
struct HardSwish<Array<half_t, N> > { struct HardSwish<Array<half_t, N> > {
using T = half_t; using T = half_t;
static const bool kIsHeavy = false;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const { Array<T, N> operator()(Array<T, N> const &value) const {
@ -427,6 +456,8 @@ struct HardSwish<Array<half_t, N> > {
// GELU operator // GELU operator
template <typename T> template <typename T>
struct GELU { struct GELU {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const &value) const { T operator()(T const &value) const {
return T(cutlass::constants::half<T>() * value * return T(cutlass::constants::half<T>() * value *
@ -436,6 +467,8 @@ struct GELU {
template <> template <>
struct GELU<float> { struct GELU<float> {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
float operator()(float const &value) const { float operator()(float const &value) const {
return cutlass::constants::half<float>() * value * return cutlass::constants::half<float>() * value *
@ -445,6 +478,8 @@ struct GELU<float> {
template <> template <>
struct GELU<double> { struct GELU<double> {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
double operator()(double const &value) const { double operator()(double const &value) const {
return cutlass::constants::half<double>() * value * return cutlass::constants::half<double>() * value *
@ -454,6 +489,8 @@ struct GELU<double> {
template <typename T, int N> template <typename T, int N>
struct GELU<Array<T, N> > { struct GELU<Array<T, N> > {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const { Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y; Array<T, N> y;
@ -475,6 +512,7 @@ using ScaledGELU = Scale<GELU<T>>;
template <typename T> template <typename T>
struct GELU_taylor { struct GELU_taylor {
static const bool kIsHeavy = true; static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const &z) const { T operator()(T const &z) const {
@ -489,6 +527,7 @@ struct GELU_taylor {
template <int N> template <int N>
struct GELU_taylor<Array<half_t, N> > { struct GELU_taylor<Array<half_t, N> > {
static const bool kIsHeavy = true; static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &z) const { Array<half_t, N> operator()(Array<half_t, N> const &z) const {
@ -515,6 +554,7 @@ struct GELU_taylor<Array<half_t, N> > {
template <typename T, int N> template <typename T, int N>
struct GELU_taylor<Array<T, N> > { struct GELU_taylor<Array<T, N> > {
static const bool kIsHeavy = true; static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &value) const { Array<T, N> operator()(Array<T, N> const &value) const {
Array<T, N> y; Array<T, N> y;
@ -536,6 +576,8 @@ using ScaledGELU_taylor = Scale<GELU_taylor<T>>;
/// z is computed from the forward pass. /// z is computed from the forward pass.
template <typename T> template <typename T>
struct dGELU { struct dGELU {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
T operator()(T const &d_t, T const &z) const { T operator()(T const &d_t, T const &z) const {
@ -554,6 +596,8 @@ struct dGELU {
template <typename T, int N> template <typename T, int N>
struct dGELU<Array<T, N> > { struct dGELU<Array<T, N> > {
static const bool kIsHeavy = true;
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &d_t, Array<T, N> const &z) const { Array<T, N> operator()(Array<T, N> const &d_t, Array<T, N> const &z) const {
Array<T, N> y; Array<T, N> y;
@ -568,6 +612,45 @@ struct dGELU<Array<T, N> > {
} }
}; };
/// Backward (gradient) functor for ReLU on a single element.
///
/// Given an incoming gradient `d_t` and a predicate `d_relu`, the result is `d_t` when the
/// predicate is true and `T(0)` otherwise.
template <typename T>
struct dReLU {
  /// Exposed for consistency with the other activation functors in this file, all of which
  /// publish kIsHeavy. The value matches the other select-based activations here
  /// (ReLu, LeakyReLU), which also report false.
  static const bool kIsHeavy = false;

  /// Returns `d_t` if `d_relu` is true, otherwise `T(0)`.
  CUTLASS_HOST_DEVICE
  T operator()(T const& d_t, bool d_relu) const {
    return d_relu ? d_t : T(0);
  }

  /// Overload taking the predicate as a `uint1b_t`; it is converted to `bool` and forwarded
  /// to the overload above.
  CUTLASS_HOST_DEVICE
  T operator()(T const& d_t, uint1b_t d_relu) const {
    return operator()(d_t, static_cast<bool>(d_relu));
  }
};
/// Backward (gradient) functor for ReLU applied element-wise to an Array<T, N>.
///
/// Each output element is `d_t[i]` when the matching predicate is true and `T(0)` otherwise,
/// as computed by the scalar `dReLU<T>`.
template <typename T, int N>
struct dReLU<Array<T, N>> {
  /// Exposed for consistency with the other activation functors in this file, all of which
  /// publish kIsHeavy. The value matches the other select-based activations here
  /// (ReLu, LeakyReLU), which also report false.
  static const bool kIsHeavy = false;

  /// Applies the scalar functor to every element, using `d_relu[i]` as the predicate for
  /// element `i`.
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const& d_t, bool const (&d_relu)[N]) const {
    Array<T, N> y;
    dReLU<T> relu_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      y[i] = relu_op(d_t[i], d_relu[i]);
    }

    return y;
  }

  /// Overload taking the predicates as an Array<uint1b_t, N>. They are expanded into a local
  /// `bool[N]` with `UnpackPredicates<N>` and then forwarded to the overload above.
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const& d_t, Array<uint1b_t, N> const& d_relu) const {
    UnpackPredicates<N> unpack_op;

    bool preds[N];
    unpack_op(preds, d_relu);

    return operator()(d_t, preds);
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread } // namespace thread