diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index c30a8209..ca1f72c2 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -98,6 +98,32 @@ struct ReLu> { } }; +// Leaky Relu operator +template +struct LeakyReLU { + CUTLASS_HOST_DEVICE + T operator()(T const &value, T const & alpha_recip) const { + T res = value > T(0) ? value : value * alpha_recip; + return res; + } +}; + +template +struct LeakyReLU > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, T const & alpha_recip) const { + Array y; + LeakyReLU leaky_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(rhs.size()); ++i) { + y[i] = leaky_op(rhs[i], alpha_recip); + } + + return y; + } +}; + // Sigmoid operator template struct Sigmoid { diff --git a/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h b/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h new file mode 100644 index 00000000..95150ef4 --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h @@ -0,0 +1,210 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationLeakyRelu { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta_bias; ///< scales bias tensor + ElementCompute leaky_alpha; ///< leaky_alpha + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta_bias(ElementCompute(0)), + leaky_alpha(ElementCompute(1)) + { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta_bias, + ElementCompute leaky_alpha = ElementCompute(1) + ): alpha(alpha), beta_bias(beta_bias), leaky_alpha(leaky_alpha) { + + } + + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_bias_; + ElementCompute leaky_alpha_recip_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationLeakyRelu(Params const ¶ms) { + alpha_ = (params.alpha); + beta_bias_ = (params.beta_bias); + leaky_alpha_recip_ = (ElementCompute(params.leaky_alpha)); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_bias_ != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition) { + if (k_partition) { + beta_bias_ = ElementCompute(1); + } + } + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_bias_ = ElementCompute(1); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + // Convert source to interal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + LeakyReLU leakyrelu; + + intermediate = mul_add_source(beta_bias_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + // Compute threshold optionally + intermediate = leakyrelu(intermediate, leaky_alpha_recip_); + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_accumulator; + LeakyReLU leakyrelu; + //printf("in doing with bias"); + intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + // Compute threshold optionally + intermediate = leakyrelu(intermediate, leaky_alpha_recip_); + + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + return destination_converter(intermediate); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////