cutlass/cutlass/gemm/linear_scaling.h
2018-09-18 16:58:03 -07:00

170 lines
5.8 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the BLAS linear scaling function alpha*AB + beta*C
*/
#pragma once
#include "cutlass/fragment_multiply_add.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
CUTLASS_DEVICE bool is_zero(T x) {
  /// Returns true when x compares equal to T(0) using T's own equality operator.
  /// For IEEE floating-point types this accepts both +0 and -0.
  T const zero = T(0);
  return x == zero;
}
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
/// Returns true if the half-precision value \p x is zero (either +0 or -0).
///
/// The test inspects the raw 16-bit encoding instead of using a half comparison operator
/// (presumably so it does not depend on native fp16 compare support -- TODO confirm).
/// The sign bit (0x8000) is masked off so that negative zero is treated as zero. This matches
/// the generic is_zero(), where T(-0) == T(0) holds under IEEE-754 rules. It also ensures
/// that a beta of -0 does not force the source matrix to be loaded.
CUTLASS_DEVICE bool is_zero(half x) {
  // int16_t promotes to int; masking with 0x7fff keeps only exponent + mantissa bits.
  return (reinterpret_cast<int16_t const&>(x) & 0x7fff) == 0;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor computing the BLAS-style linear combination  output = alpha * accum + beta * old.
///
/// \tparam Scalar_               Type of the alpha/beta scaling factors.
/// \tparam FragmentMultiplyAdd_  Adapter providing multiply() and multiply_add() on fragments;
///                               it also defines the accumulator scalar type (ScalarAccum).
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
struct LinearScaling {
  /// The scalar type of alpha and beta.
  typedef Scalar_ Scalar;
  /// The accumulator scalar type, as defined by the multiply-add adapter.
  typedef typename FragmentMultiplyAdd_::ScalarAccum ScalarAccum;
  /// The multiply-add adapter.
  typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;

  /// The parameters.
  struct Params {
    /// The alpha/beta scaling params.
    Scalar alpha, beta;

    //
    // Methods
    //

    /// Constructs the parameters; both scaling factors default to zero.
    CUTLASS_HOST_DEVICE
    Params(Scalar _alpha = 0, Scalar _beta = 0) : alpha(_alpha), beta(_beta) {}

    /// Initializes the parameters from explicit scaling factors. Returns 0.
    CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta) {
      alpha = _alpha;
      beta = _beta;
      return 0;
    }

    /// Initializes the parameters from any descriptor that exposes `alpha` and `beta`
    /// members. Returns 0.
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
      alpha = desc.alpha;
      beta = desc.beta;
      return 0;
    }
  };

  //
  // Data members
  //

  /// Scaling factors used by every evaluate() overload.
  Params params;

  //
  // Methods
  //

  /// Ctor. params is default-constructed (alpha = 0, beta = 0).
  CUTLASS_DEVICE LinearScaling() {}

  /// Ctor.
  CUTLASS_DEVICE LinearScaling(Params const& _params) : params(_params) {}

  /// Method to determine whether the source accumulator matrix C is ever needed. This method
  /// may always safely return true, though better performance is possible if the source
  /// accumulator matrix is never loaded unnecessarily. Here C is needed iff beta is non-zero.
  CUTLASS_DEVICE
  bool source_required() const {
    return !is_zero(params.beta);
  }

  /// Evaluates output = alpha * accum via the adapter's multiply(); the source C is not used.
  template <typename FragmentA_, typename FragmentB_>
  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
    FragmentMultiplyAdd mad;
    mad.multiply(params.alpha, accum, output);
  }

  /// Pointer-based form of evaluate(accum, output) for callers that do not use fragments.
  ///
  /// \tparam AccumScalar_   Element type of \p accum. This is a per-call type, deliberately
  ///                        named differently from the class's ScalarAccum typedef.
  /// \tparam OutputScalar_  Element type of \p output.
  /// \tparam kElements_     Number of elements; not deducible, so it must be given explicitly.
  template <typename AccumScalar_, typename OutputScalar_, int kElements_>
  CUTLASS_DEVICE void evaluate(AccumScalar_ const *accum, OutputScalar_ *output) {
    Fragment<AccumScalar_, kElements_> frag_accum;
    Fragment<OutputScalar_, kElements_> frag_output;

    // NOTE(review): output is pre-loaded into frag_output. This may be redundant if
    // multiply() overwrites every element -- kept to preserve the original behavior.
#pragma unroll
    for (int i = 0; i < kElements_; i++) {
      frag_accum[i] = accum[i];
      frag_output[i] = output[i];
    }

    evaluate(frag_accum, frag_output);

#pragma unroll
    for (int i = 0; i < kElements_; i++) {
      output[i] = frag_output[i];
    }
  }

  /// Evaluates output = alpha * accum + beta * old.
  /// beta * old is formed first into a temporary, then the adapter's multiply_add() adds
  /// alpha * accum to it.
  template <typename FragmentA_, typename FragmentB_>
  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
    FragmentMultiplyAdd mad;
    FragmentB_ tmp;
    mad.multiply(params.beta, old, tmp);
    mad.multiply_add(params.alpha, accum, tmp, output);
  }

  /// Pointer-based form of evaluate(accum, old, output) for callers that do not use fragments.
  ///
  /// \tparam AccumScalar_   Element type of \p accum. This is a per-call type, deliberately
  ///                        named differently from the class's ScalarAccum typedef.
  /// \tparam OutputScalar_  Element type of \p old and \p output.
  /// \tparam kElements_     Number of elements; not deducible, so it must be given explicitly.
  template <typename AccumScalar_, typename OutputScalar_, int kElements_>
  CUTLASS_DEVICE void evaluate(AccumScalar_ const *accum, OutputScalar_ const *old, OutputScalar_ *output) {
    Fragment<AccumScalar_, kElements_> frag_accum;
    Fragment<OutputScalar_, kElements_> frag_output;
    Fragment<OutputScalar_, kElements_> frag_old;

    // NOTE(review): the pre-load of output may be redundant (see the two-operand overload).
#pragma unroll
    for (int i = 0; i < kElements_; i++) {
      frag_accum[i] = accum[i];
      frag_output[i] = output[i];
      frag_old[i] = old[i];
    }

    evaluate(frag_accum, frag_old, frag_output);

#pragma unroll
    for (int i = 0; i < kElements_; i++) {
      output[i] = frag_output[i];
    }
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass