37 template <
typename Scalar_>
52 template <
typename Fragment_>
53 CUTLASS_DEVICE
void multiply(Scalar_ a, Fragment_
const& b, Fragment_& d) {
54 for (
int j = 0; j < Fragment_::kElements; ++j) {
60 template <
typename Fragment_>
65 for (
int j = 0; j < Fragment_::kElements; ++j) {
66 d[j] = a * b[j] + c[j];
73 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16) 89 template <
typename Fragment_>
90 CUTLASS_DEVICE
void multiply(half a, Fragment_
const& b, Fragment_& d) {
91 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 93 __half2
const* b_half2 =
reinterpret_cast<__half2 const*
>(&b[0]);
95 __half2* d_half2 =
reinterpret_cast<__half2*
>(&d[0]);
98 __half2
const a_half2 = __half2half2(a);
100 for (
int i = 0; i < Fragment_::kElements / 2; ++i) {
101 d_half2[i] = __hmul2(a_half2, b_half2[i]);
107 template <
typename Fragment_>
108 CUTLASS_DEVICE
void multiply_add(half a, Fragment_
const& b, Fragment_
const& c, Fragment_& d) {
109 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 111 __half2
const* b_half2 =
reinterpret_cast<__half2 const*
>(&b[0]);
112 __half2
const* c_half2 =
reinterpret_cast<__half2 const*
>(&c[0]);
114 __half2* d_half2 =
reinterpret_cast<__half2*
>(&d[0]);
117 __half2
const a_half2 = __half2half2(a);
119 for (
int i = 0; i < Fragment_::kElements / 2; ++i) {
120 d_half2[i] = __hfma2(a_half2, b_half2[i], c_half2[i]);
Scalar_ ScalarB
The type for B.
Definition: fragment_multiply_add.h:44
CUTLASS_DEVICE void multiply(Scalar_ a, Fragment_ const &b, Fragment_ &d)
Multiply : d = a*b.
Definition: fragment_multiply_add.h:53
half ScalarA
The type for A.
Definition: fragment_multiply_add.h:79
CUTLASS_DEVICE FragmentMultiplyAdd()
Ctor.
Definition: fragment_multiply_add.h:86
CUTLASS_DEVICE void multiply_add(Scalar_ a, Fragment_ const &b, Fragment_ const &c, Fragment_ &d)
Multiply : d = a*b + c.
Definition: fragment_multiply_add.h:61
half ScalarC
The type for C and D.
Definition: fragment_multiply_add.h:83
CUTLASS_DEVICE void multiply_add(half a, Fragment_ const &b, Fragment_ const &c, Fragment_ &d)
Multiply : d = a*b + c.
Definition: fragment_multiply_add.h:108
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fragment_multiply_add.h:40
Scalar_ ScalarC
The type for C and D.
Definition: fragment_multiply_add.h:46
Scalar_ ScalarA
The type for A.
Definition: fragment_multiply_add.h:42
CUTLASS_DEVICE FragmentMultiplyAdd()
Ctor.
Definition: fragment_multiply_add.h:49
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
CUTLASS_DEVICE void multiply(half a, Fragment_ const &b, Fragment_ &d)
Multiply : d = a*b.
Definition: fragment_multiply_add.h:90
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fragment_multiply_add.h:77
half ScalarB
The type for B.
Definition: fragment_multiply_add.h:81
Definition: fragment_multiply_add.h:38