Refactor GELU and Sigmoid epilogue to use a common template (and add SiLu, Hardswish epilogue) (#379)

* Support half precision sigmoid activation

* introduce a vectorized variant using fast_tanh

* refactored sigmoid using the new interface

* refactored gelu

* add silu activation

* add hardswish

* remove sigmoid for now

* add description to silu and hardswish, and other doc update

* Do not ignore Round

* use constant N

* Set isHeavy = true in sigmoid and silu epilogue
This commit is contained in:
masahi 2021-12-19 04:58:15 +09:00 committed by GitHub
parent ec4f7e5194
commit 0dc3ba60b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 403 additions and 309 deletions

View File

@ -126,6 +126,61 @@ struct Sigmoid<Array<T, N> > {
}
};
// SiLu (swish) operator introduced by Elfwing et al. in the following paper
// "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017)
// https://arxiv.org/pdf/1702.03118.pdf
// It is used in EfficientNet and YOLOv5, for example.
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
template <typename T>
struct SiLu {
CUTLASS_HOST_DEVICE
T operator()(T const &scalar) const {
return scalar * Sigmoid<T>(scalar);
}
};
template <typename T, int N>
struct SiLu<Array<T, N>> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Sigmoid<Array<T, N>> sigmoid_op;
multiplies<Array<T, N>> mul;
return mul(rhs, sigmoid_op(rhs));
}
};
// Hardswish operator introduced by Howard et al. in the following paper
// "Searching for MobileNetV3" (2019)
// https://arxiv.org/pdf/1905.02244.pdf
// It is used in models based on MobilenetNetV3.
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
template <typename T>
struct HardSwish {
CUTLASS_HOST_DEVICE
T operator()(T const &x) const {
minimum<T> mn;
maximum<T> mx;
T relu6 = mn(mx(x + T(3), T(0)), T(6));
return x * (relu6 / T(6));
}
};
template <typename T, int N>
struct HardSwish<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> y;
HardSwish<T> hardswish_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = hardswish_op(rhs[i]);
}
return y;
}
};
//
// GELU function definitions implemented as described by
// Hendrycks, D., and Gimpel, K. in
@ -189,7 +244,7 @@ struct GELU_taylor {
T k0 = T(0.7978845608028654);
T k1 = T(0.044715);
return T(cutlass::constants::half<T>() * z *
return T(cutlass::constants::half<T>() * z *
(cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
}
};
@ -199,10 +254,10 @@ struct GELU_taylor<Array<half_t, N> > {
static const bool kIsHeavy=true;
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &z) const {
using T = half_t;
Array<half_t, N> y;
half_t k0 = half_t(0.7978845608028654);
half_t k1 = half_t(0.044715);
@ -250,7 +305,7 @@ struct dGELU {
T tanh_out = fast_tanh(k0 * z * (1 + k1 * z * z));
T ff = constants::half<T>() * z * ((1 - tanh_out * tanh_out) * (k0 + k2 * z * z)) +
T ff = constants::half<T>() * z * ((1 - tanh_out * tanh_out) * (k0 + k2 * z * z)) +
constants::half<T>() * (1 + tanh_out);
return ff * d_t;

View File

@ -29,12 +29,8 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -44,9 +40,9 @@ namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
/// Applies a linear combination operator followed by the GELU activation to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
/// D = gelu(alpha * accumulator + beta * source + uniform)
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
@ -57,153 +53,9 @@ template <
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationGELU {
public:
using LinearCombinationGELU = LinearCombinationGeneric<GELU, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round, true>;
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static bool const kIsHeavy = true;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationGELU(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return beta_ != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}
CUTLASS_UNUSED(k_partition_count);
}
/// Computes: D = gelu( alpha * accumulator + beta * source )
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
GELU<ComputeFragment> gelu;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = gelu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
/// Computes: D = gelu( alpha * accumulator )
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_accumulator;
GELU<ComputeFragment> gelu;
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
intermediate = gelu(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,209 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator followed by an activation function to an array of elements.
///
/// D = activation(alpha * accumulator + beta * source + uniform)
///
template <
template<typename T> class ActivationFunctor,
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
bool IsHeavy = false
>
class LinearCombinationGeneric {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static bool const kIsHeavy = IsHeavy;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationGeneric(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return beta_ != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
ActivationFunctor<ComputeFragment> activation;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = activation(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_accumulator;
ActivationFunctor<ComputeFragment> activation;
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
intermediate = activation(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass

View File

@ -0,0 +1,62 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with HardSwish operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator followed by the HardSwish activation to an array of elements.
///
/// D = hardswish(alpha * accumulator + beta * source + uniform)
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationHardSwish = LinearCombinationGeneric<HardSwish, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round>;
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass

View File

@ -23,18 +23,14 @@
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations used by epilogues.
\brief Functor performing linear combination with Sigmoid operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -44,9 +40,9 @@ namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
/// Applies a linear combination operator followed by the Sigmoid activation, to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
/// D = sigmoid(alpha * accumulator + beta * source + uniform)
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
@ -57,150 +53,8 @@ template <
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationSigmoid {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationSigmoid(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return beta_ != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
Sigmoid<ComputeFragment> sigmoid;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = sigmoid(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_accumulator;
Sigmoid<ComputeFragment> sigmoid;
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
intermediate = sigmoid(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
using LinearCombinationSigmoid = LinearCombinationGeneric<Sigmoid, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round, true>;
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread

View File

@ -0,0 +1,62 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with SiLU operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator folllowed by the SiLU activation to an array of elements.
///
/// D = silu(alpha * accumulator + beta * source + uniform)
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there are not enough data to store
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationSilu = LinearCombinationGeneric<SiLu, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round, true>;
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass