31 #ifdef CUTLASS_USE_WMMA_API 45 typename AccumulatorsPerWarp_,
46 typename InstructionShape_>
47 struct WmmaGemmMultiplyAdd {
49 typedef InstructionShape_ InstructionShape;
51 typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
53 typedef AccumulatorsPerWarp_ AccumulatorsPerWarp;
55 typedef ScalarA_ ScalarA;
57 typedef ScalarB_ ScalarB;
59 typedef ScalarC_ ScalarC;
64 typedef WmmaMatrix<GemmOperand::kA, kLayoutA_, ScalarA, InstructionShape> ElementA;
66 typedef Fragment<ElementA, Iterations::kW> FragmentA;
69 typedef WmmaMatrix<GemmOperand::kB, kLayoutB_, ScalarB, InstructionShape> ElementB;
71 typedef Fragment<ElementB, Iterations::kH> FragmentB;
74 typedef WmmaMatrix<GemmOperand::kC, kLayoutC_, ScalarC, InstructionShape> ElementC;
76 typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
79 CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
82 CUTLASS_DEVICE
void multiply_add(FragmentA
const& a,
84 Accumulators
const& c,
86 for (
int j = 0; j < Iterations::kH; ++j) {
87 for (
int i = 0; i < Iterations::kW; ++i) {
89 ElementA
const& elt_a = a[i];
90 ElementB
const& elt_b = b[j];
91 ElementC
const& elt_c = c[j * Iterations::kW + i];
94 ElementC& elt_d = d[j * Iterations::kW + i];
97 nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
108 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
Kind
Definition: matrix_traits.h:36
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...