Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_multiply_add.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/wmma_matrix.h>
31 #ifdef CUTLASS_USE_WMMA_API
32 #include <cutlass/fragment.h>
33 
34 namespace cutlass {
35 namespace gemm {
36 
38 
39 template <MatrixLayout::Kind kLayoutA_,
40  typename ScalarA_,
41  MatrixLayout::Kind kLayoutB_,
42  typename ScalarB_,
43  MatrixLayout::Kind kLayoutC_,
44  typename ScalarC_,
45  typename AccumulatorsPerWarp_,
46  typename InstructionShape_>
47 struct WmmaGemmMultiplyAdd {
49  typedef InstructionShape_ InstructionShape;
51  typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
53  typedef AccumulatorsPerWarp_ AccumulatorsPerWarp;
55  typedef ScalarA_ ScalarA;
57  typedef ScalarB_ ScalarB;
59  typedef ScalarC_ ScalarC;
62 
64  typedef WmmaMatrix<GemmOperand::kA, kLayoutA_, ScalarA, InstructionShape> ElementA;
66  typedef Fragment<ElementA, Iterations::kW> FragmentA;
67 
69  typedef WmmaMatrix<GemmOperand::kB, kLayoutB_, ScalarB, InstructionShape> ElementB;
71  typedef Fragment<ElementB, Iterations::kH> FragmentB;
72 
74  typedef WmmaMatrix<GemmOperand::kC, kLayoutC_, ScalarC, InstructionShape> ElementC;
76  typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
77 
79  CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
80 
82  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
83  FragmentB const& b,
84  Accumulators const& c,
85  Accumulators& d) {
86  for (int j = 0; j < Iterations::kH; ++j) {
87  for (int i = 0; i < Iterations::kW; ++i) {
88  // The input elements.
89  ElementA const& elt_a = a[i];
90  ElementB const& elt_b = b[j];
91  ElementC const& elt_c = c[j * Iterations::kW + i];
92 
93  // The output element.
94  ElementC& elt_d = d[j * Iterations::kW + i];
95 
96  // The wmma instruction.
97  nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
98  }
99  }
100  }
101 };
102 
104 
105 } // namespace gemm
106 } // namespace cutlass
107 
108 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: convert.h:33
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
Kind
Definition: matrix_traits.h:36
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...