271 lines
8.6 KiB
C++
271 lines
8.6 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
|
|
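/*! \file
  \brief Functor performing a linear combination of accumulator and source, followed by a binary
         operation with a per-element vector operand and an elementwise operation; used by epilogues.
*/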
|
|
|
|
#pragma once
|
|
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/numeric_types.h"
|
|
#include "cutlass/array.h"
|
|
#include "cutlass/functional.h"
|
|
#include "cutlass/numeric_conversion.h"
|
|
|
|
#include "cutlass/epilogue/thread/activation.h"
|
|
#include "cutlass/epilogue/thread/scale_type.h"
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass {
|
|
namespace epilogue {
|
|
namespace thread {
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// This base class is meant to define the concept required of the
|
|
/// EpilogueWithBroadcast::OutputOp
|
|
template <
|
|
typename ElementC_,
|
|
typename ElementAccumulator_,
|
|
typename ElementCompute_,
|
|
typename ElementZ_,
|
|
typename ElementT_,
|
|
int ElementsPerAccess,
|
|
typename ElementwiseOp_ = Identity<ElementCompute_>,
|
|
typename BinaryOp_ = plus<ElementCompute_>,
|
|
bool StoreT_ = true,
|
|
typename ElementVector_ = ElementC_
|
|
>
|
|
class LinearCombinationBiasElementwise {
|
|
public:
|
|
|
|
using ElementOutput = ElementC_;
|
|
using ElementC = ElementC_;
|
|
using ElementAccumulator = ElementAccumulator_;
|
|
using ElementCompute = ElementCompute_;
|
|
using ElementZ = ElementZ_;
|
|
using ElementT = ElementT_;
|
|
using ElementVector = ElementVector_;
|
|
static int const kElementsPerAccess = ElementsPerAccess;
|
|
static int const kCount = kElementsPerAccess;
|
|
|
|
using ElementwiseOp = ElementwiseOp_;
|
|
using BinaryOp = BinaryOp_;
|
|
|
|
// Indicates that this epilogue applies only one binary operation
|
|
static bool const kIsSingleSource = true;
|
|
|
|
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
|
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
|
|
using FragmentC = Array<ElementOutput, kElementsPerAccess>;
|
|
using FragmentZ = Array<ElementZ, kElementsPerAccess>;
|
|
using FragmentT = Array<ElementT, kElementsPerAccess>;
|
|
|
|
// Definitions needed for collective epilogue
|
|
using FragmentSource = FragmentC;
|
|
using FragmentOutput = FragmentZ;
|
|
using ElementBias = ElementVector;
|
|
using FragmentBias = FragmentCompute;
|
|
using ActivationFunctor = ElementwiseOp;
|
|
static const ScaleType::Kind kScale = ScaleType::Default;
|
|
|
|
static bool const kIsHeavy = ElementwiseOp::kIsHeavy;
|
|
|
|
/// If true, the 'Z' tensor is stored
|
|
static bool const kStoreZ = true;
|
|
|
|
/// If true, the 'T' tensor is stored
|
|
static bool const kStoreT = StoreT_;
|
|
|
|
/// Host-constructable parameters structure
|
|
struct Params {
|
|
|
|
ElementCompute alpha; ///< scales accumulators
|
|
ElementCompute beta; ///< scales source tensor
|
|
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
|
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
Params():
|
|
alpha(ElementCompute(1)),
|
|
beta(ElementCompute(0)),
|
|
alpha_ptr(nullptr),
|
|
beta_ptr(nullptr) { }
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
Params(
|
|
ElementCompute alpha,
|
|
ElementCompute beta
|
|
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
|
|
|
}
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
Params(
|
|
ElementCompute alpha
|
|
): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
|
|
|
}
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
Params(
|
|
ElementCompute const *alpha_ptr,
|
|
ElementCompute const *beta_ptr
|
|
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
|
|
|
}
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
Params(
|
|
ElementCompute const *alpha_ptr
|
|
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) {
|
|
|
|
}
|
|
};
|
|
|
|
private:
|
|
|
|
//
|
|
// Data members
|
|
//
|
|
|
|
ElementCompute alpha_;
|
|
ElementCompute beta_;
|
|
bool skip_elementwise_;
|
|
|
|
public:
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
/// Constructor from Params
|
|
CUTLASS_HOST_DEVICE
|
|
LinearCombinationBiasElementwise(Params const ¶ms) {
|
|
|
|
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
|
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
|
skip_elementwise_ = false;
|
|
}
|
|
|
|
/// Returns true if source is needed
|
|
CUTLASS_HOST_DEVICE
|
|
bool is_source_needed() const {
|
|
return beta_ != ElementCompute(0);
|
|
}
|
|
|
|
/// Functionally required for serial reduction in the epilogue
|
|
CUTLASS_HOST_DEVICE
|
|
void set_k_partition(int k_partition, int k_partition_count) {
|
|
if (k_partition) {
|
|
beta_ = ElementCompute(1);
|
|
}
|
|
|
|
if (k_partition != k_partition_count - 1) {
|
|
skip_elementwise_ = true;
|
|
}
|
|
}
|
|
|
|
/// Applies the operation when is_source_needed() is true
|
|
CUTLASS_HOST_DEVICE
|
|
void operator()(
|
|
FragmentZ &frag_Z,
|
|
FragmentT &frag_T,
|
|
FragmentAccumulator const &AB,
|
|
FragmentC const &frag_C,
|
|
FragmentCompute const &V) const {
|
|
|
|
ElementwiseOp elementwise_op;
|
|
BinaryOp binary_op;
|
|
|
|
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
|
FragmentCompute tmp_C = NumericArrayConverter<ElementCompute, ElementC, kElementsPerAccess>()(frag_C);
|
|
FragmentCompute result_Z;
|
|
FragmentCompute result_T;
|
|
|
|
CUTLASS_PRAGMA_UNROLL
|
|
for (int i = 0; i < kElementsPerAccess; ++i) {
|
|
ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
|
|
result_T[i] = z;
|
|
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
|
|
}
|
|
|
|
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
|
frag_Z = convert_z(result_Z);
|
|
|
|
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
|
frag_T = convert_t(result_T);
|
|
}
|
|
|
|
/// Applies the operation when is_source_needed() is false
|
|
CUTLASS_HOST_DEVICE
|
|
void operator()(
|
|
FragmentZ &frag_Z,
|
|
FragmentT &frag_T,
|
|
FragmentAccumulator const &AB,
|
|
FragmentCompute const &V) const {
|
|
|
|
ElementwiseOp elementwise_op;
|
|
BinaryOp binary_op;
|
|
|
|
FragmentCompute tmp_Accum = NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
|
|
FragmentCompute result_Z;
|
|
FragmentCompute result_T;
|
|
|
|
CUTLASS_PRAGMA_UNROLL
|
|
for (int i = 0; i < kElementsPerAccess; ++i) {
|
|
ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
|
|
result_T[i] = z;
|
|
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
|
|
}
|
|
|
|
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
|
|
frag_Z = convert_z(result_Z);
|
|
|
|
NumericArrayConverter<ElementT, ElementCompute, kElementsPerAccess> convert_t;
|
|
frag_T = convert_t(result_T);
|
|
}
|
|
};
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace thread
|
|
} // namespace epilogue
|
|
} // namespace cutlass
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|