39 template <
typename Scalar_,
typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_> >
52 template <
typename GemmDesc_>
64 template <
typename Fragment_>
65 CUTLASS_DEVICE
void evaluate(Fragment_
const& accum, Fragment_& output) {
67 mad.multiply(
alpha, accum, output);
71 template <
typename Fragment_>
72 CUTLASS_DEVICE
void evaluate(Fragment_
const& accum, Fragment_
const& old, Fragment_& output) {
75 mad.multiply(
beta, old, tmp);
76 mad.multiply_add(
alpha, accum, tmp, output);
Scalar alpha
The alpha/beta scaling params.
Definition: linear_scaling.h:49
Scalar alpha
The alpha/beta scaling factors.
Definition: linear_scaling.h:80
CUTLASS_DEVICE LinearScaling(Params const &params)
Ctor.
Definition: linear_scaling.h:61
CUTLASS_DEVICE void evaluate(Fragment_ const &accum, Fragment_ const &old, Fragment_ &output)
Evaluate the functor.
Definition: linear_scaling.h:72
Scalar beta
Definition: linear_scaling.h:49
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: linear_scaling.h:53
Scalar beta
Definition: linear_scaling.h:80
Defines multiply-add operations on fragments within a thread.
FragmentMultiplyAdd_ FragmentMultiplyAdd
Definition: linear_scaling.h:44
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE void evaluate(Fragment_ const &accum, Fragment_ &output)
Evaluate the functor.
Definition: linear_scaling.h:65
The parameters.
Definition: linear_scaling.h:47
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:40
Scalar_ Scalar
Definition: linear_scaling.h:42