Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/wmma_matrix.h>
31 #ifdef CUTLASS_USE_WMMA_API
32 
33 #include <cutlass/convert.h>
34 #include <cutlass/gemm/gemm.h>
43 
44 namespace cutlass {
45 namespace gemm {
46 
48 
/// Configuration of a GEMM that uses the WMMA API.
///
/// Declares no members of its own: it only instantiates GemmConfig with `half` as the
/// first two type arguments, ScalarC_ as the next two, OutputTile_, a WmmaGemmMultiplyAdd
/// functor, and the integer widths listed below. The meaning of each positional GemmConfig
/// argument is defined outside this file; notes marked "presumably" should be confirmed there.
///
/// Template parameters (forwarded as shown in the base-class argument list):
///   kLayoutA_, kLayoutB_   layouts passed to WmmaGemmMultiplyAdd for A and B
///   OutputTile_            passed straight to GemmConfig
///   ScalarC_               passed twice to GemmConfig; also sizes the last three widths
///   Accumulator_, AccumulatorsPerWarp_, InstructionShape_   passed to WmmaGemmMultiplyAdd
///   kScalarsPerLdgA_, kScalarsPerLdgB_   access widths for A and B
49 template <
51  MatrixLayout::Kind kLayoutA_,
53  MatrixLayout::Kind kLayoutB_,
55  typename OutputTile_,
57  typename ScalarC_,
59  typename Accumulator_,
61  typename AccumulatorsPerWarp_,
63  typename InstructionShape_,
65  int kScalarsPerLdgA_,
67  int kScalarsPerLdgB_>
68 struct WmmaGemmConfig : public GemmConfig<
 // A and B scalar types are fixed to half.
70  half,
72  half,
 // ScalarC_ is used for both of the next two scalar slots (presumably C and D - confirm).
74  ScalarC_,
76  ScalarC_,
78  OutputTile_,
 // The multiply-add functor. The fifth argument is hard-coded to column-major
 // (presumably the layout of the accumulator/C matrix - confirm against WmmaGemmMultiplyAdd).
80  WmmaGemmMultiplyAdd<kLayoutA_,
81  half,
82  kLayoutB_,
83  half,
84  MatrixLayout::kColumnMajor,
85  Accumulator_,
86  AccumulatorsPerWarp_,
87  InstructionShape_>,
 // A-side widths: kScalarsPerLdgA_ twice, then a fixed 8
 // (presumably LDG, STS and LDS widths in that order - confirm).
89  kScalarsPerLdgA_,
91  kScalarsPerLdgA_,
93  8,
 // B-side widths, same pattern as A.
95  kScalarsPerLdgB_,
97  kScalarsPerLdgB_,
99  8,
 // Three widths, each equal to the number of ScalarC_ elements in 16 bytes.
101  16 / sizeof(ScalarC_),
103  16 / sizeof(ScalarC_),
105  16 / sizeof(ScalarC_),
 // Final argument is 1 (presumably the stage count read elsewhere as GemmConfig_::kStages - confirm).
107  1> {};
108 
110 
/// Primary template of the tile-traits helper for operand A. It declares no members;
/// the MatrixLayout::kColumnMajor and MatrixLayout::kRowMajor specializations in this
/// file provide the actual definitions.
111 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
112 struct WmmaGemmTileTraitsHelperA {};
113 
115 
/// Tile traits for operand A when A is column-major.
///
/// Inherits from GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> (aliased
/// as Base) and defines the shared-memory side: a Tile padded by kSkew, the WmmaMatrix
/// type for operand A, and the traits used to store to and load from that tile.
///
/// NOTE(review): the embedded listing numbers skip 132 and 156 inside open template
/// argument lists, so one argument (apparently the matrix layout) appears to be missing
/// from each of the WmmaMatrix and SharedLoadTileTraits typedefs as shown here; restore
/// them from the upstream header before compiling.
116 template <typename GemmConfig_>
117 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
118  : public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
 // Shorthand for the inherited helper.
120  typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
121 
 // Number of MultiplyAddScalar elements in 16 bytes; added to the W extent of Tile below.
123  static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
 // Tile shape: kStages x OutputTile::kD x (OutputTile::kW + kSkew).
125  typedef Shape<GemmConfig_::kStages,
126  GemmConfig_::OutputTile::kD,
127  GemmConfig_::OutputTile::kW + kSkew>
128  Tile;
129 
 // WMMA matrix type for operand A, built from the multiply-add scalar and the instruction shape.
131  typedef WmmaMatrix<GemmOperand::kA,
133  typename Base::MultiplyAddScalar,
134  typename GemmConfig_::InstructionShape>
135  WmmaMatrix;
136 
 // Traits for storing A into the tile defined above.
138  typedef GemmSharedStoreTileAbTraits<
139  // The scalar type.
140  typename Base::MultiplyAddScalar,
141  // The tile has size KxM in GEMM's terminology.
142  Tile,
143  // The thread arrangement is taken from the inherited global tile traits.
144  typename Base::GlobalTileTraits::Threads,
145  // The number of scalars per STS (STS.32 or STS.128, etc).
146  GemmConfig_::kScalarsPerStsA>
147  SharedStoreTileTraits;
148 
 // Product of the instruction's W extent and the number of warps along W.
150  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
 // Tile::kW (which includes kSkew) times the instruction's D extent.
152  static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
 // Traits for loading A from the tile defined above.
154  typedef WmmaGemmSharedLoadTileATraits<
155  // The layout of the matrix.
157  // The scalar type.
158  typename Base::MultiplyAddScalar,
159  // The shared-memory tile defined above.
160  Tile,
161  // The number of warps.
162  typename GemmConfig_::Warps,
163  // The strides between warps.
164  GemmConfig_::InstructionShape::kW,
165  // The number of iterations to load the data.
166  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
167  // The stride between iterations.
168  Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
169  // The shape of the instruction.
170  typename GemmConfig_::InstructionShape>
171  SharedLoadTileTraits;
172 };
173 
175 
/// Tile traits for operand A when A is row-major.
///
/// Unlike the column-major specialization, this one has no base class and defines
/// everything itself: the scalar types, the WmmaMatrix type, the global tile traits,
/// the padded shared-memory Tile, and the shared store/load tile traits.
///
/// NOTE(review): the embedded listing numbers skip 188, 196, 198 and 234 inside open
/// template argument lists (each gap follows a comment that names an argument), so the
/// corresponding arguments appear to be missing from this extract; restore them from
/// the upstream header before compiling.
176 template <typename GemmConfig_>
177 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
 // The layout handled by this specialization.
179  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
180 
 // Scalar type of A as declared by the config.
182  typedef typename GemmConfig_::ScalarA Scalar;
 // Scalar type of A as seen by the multiply-add functor.
184  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
185 
 // WMMA matrix type for operand A.
187  typedef WmmaMatrix<GemmOperand::kA,
189  MultiplyAddScalar,
190  typename GemmConfig_::InstructionShape>
191  WmmaMatrix;
192 
 // Traits describing how A is read from global memory.
194  typedef GemmGlobalTileTraits<
195  // That's A.
197  // A is row-major.
199  // The pointer is Scalar const (GemmConfig_::ScalarA), not necessarily float.
200  Scalar const,
201  // The tile is OutputTile::kW x OutputTile::kD (kW first, unlike the column-major helper).
202  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
203  // The threads are arranged as (kThreads / OutputTile::kD) x OutputTile::kD.
204  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
205  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
206  GemmConfig_::kScalarsPerLdgA>
207  GlobalTileTraits;
208 
 // Number of MultiplyAddScalar elements in 16 bytes; added to the last extent of Tile below.
210  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
 // Tile shape: kStages x OutputTile::kW x (OutputTile::kD + kSkew).
212  typedef Shape<GemmConfig_::kStages,
213  GemmConfig_::OutputTile::kW,
214  GemmConfig_::OutputTile::kD + kSkew>
215  Tile;
216 
 // Traits for storing A into the tile defined above.
218  typedef GemmSharedStoreTileAbTraits<
219  // The scalar type.
220  MultiplyAddScalar,
221  // The shared-memory tile defined above (skew included).
222  Tile,
223  // The thread arrangement is taken from GlobalTileTraits above.
224  typename GlobalTileTraits::Threads,
225  // The number of scalars per STS (STS.32 or STS.128, etc).
226  GemmConfig_::kScalarsPerStsA>
227  SharedStoreTileTraits;
228 
 // Product of the instruction's W extent and the number of warps along W.
230  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
 // Traits for loading A from the tile defined above.
232  typedef WmmaGemmSharedLoadTileATraits<
233  // The layout of the matrix.
235  // The scalar type.
236  MultiplyAddScalar,
237  // The tile in shared memory.
238  Tile,
239  // The number of warps.
240  typename GemmConfig_::Warps,
241  // The strides between warps (scaled by Tile::kW, which includes the skew).
242  GemmConfig_::InstructionShape::kW * Tile::kW,
243  // The number of iterations to load the data.
244  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
245  // The stride between iterations.
246  Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
247  // The shape of the instruction.
248  typename GemmConfig_::InstructionShape>
249  SharedLoadTileTraits;
250 };
251 
253 
/// Primary template of the tile-traits helper for operand B. It declares no members;
/// the MatrixLayout::kRowMajor and MatrixLayout::kColumnMajor specializations in this
/// file provide the actual definitions.
254 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
255 struct WmmaGemmTileTraitsHelperB {};
256 
258 
/// Tile traits for operand B when B is row-major.
///
/// Inherits from GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> (aliased as
/// Base) and defines the shared-memory side: a Tile padded by kSkew, the WmmaMatrix type
/// for operand B, and the traits used to store to and load from that tile. It mirrors
/// the column-major A specialization with OutputTile::kH / Warps::kH in place of kW.
///
/// NOTE(review): the embedded listing numbers skip 275 and 299 inside open template
/// argument lists, so one argument (apparently the matrix layout) appears to be missing
/// from each of the WmmaMatrix and SharedLoadTileTraits typedefs as shown here; restore
/// them from the upstream header before compiling.
259 template <typename GemmConfig_>
260 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
261  : public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
 // Shorthand for the inherited helper.
263  typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
264 
 // Number of MultiplyAddScalar elements in 16 bytes; added to the last extent of Tile below.
266  static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
 // Tile shape: kStages x OutputTile::kD x (OutputTile::kH + kSkew).
268  typedef Shape<GemmConfig_::kStages,
269  GemmConfig_::OutputTile::kD,
270  GemmConfig_::OutputTile::kH + kSkew>
271  Tile;
272 
 // WMMA matrix type for operand B.
274  typedef WmmaMatrix<GemmOperand::kB,
276  typename Base::MultiplyAddScalar,
277  typename GemmConfig_::InstructionShape>
278  WmmaMatrix;
279 
 // Traits for storing B into the tile defined above.
281  typedef GemmSharedStoreTileAbTraits<
282  // The scalar type.
283  typename Base::MultiplyAddScalar,
284  // The shared-memory tile defined above (kStages x kD x (kH + kSkew)).
285  Tile,
286  // The thread arrangement is taken from the inherited global tile traits.
287  typename Base::GlobalTileTraits::Threads,
288  // The number of scalars per STS (STS.32 or STS.128, etc).
289  GemmConfig_::kScalarsPerStsB>
290  SharedStoreTileTraits;
291 
 // Product of the instruction's H extent and the number of warps along H.
293  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
 // Tile::kW (which includes kSkew) times the instruction's D extent.
295  static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
 // Traits for loading B from the tile defined above.
297  typedef WmmaGemmSharedLoadTileBTraits<
298  // The layout of the matrix.
300  // The scalar type.
301  typename Base::MultiplyAddScalar,
302  // The shared-memory tile defined above.
303  Tile,
304  // The number of warps.
305  typename GemmConfig_::Warps,
306  // The strides between warps.
307  GemmConfig_::InstructionShape::kH,
308  // The number of iterations to load the data.
309  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
310  // The stride between iterations.
311  Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
312  // The shape of the instruction.
313  typename GemmConfig_::InstructionShape>
314  SharedLoadTileTraits;
315 };
316 
318 
/// Tile traits for operand B when B is column-major.
///
/// Like the row-major A specialization, this one has no base class and defines
/// everything itself: the scalar types, the WmmaMatrix type, the global tile traits,
/// the padded shared-memory Tile, and the shared store/load tile traits. It uses
/// OutputTile::kH / Warps::kH where the A helper uses kW.
///
/// NOTE(review): the embedded listing numbers skip 331, 339, 341 and 377 inside open
/// template argument lists (each gap follows a comment that names an argument), so the
/// corresponding arguments appear to be missing from this extract; restore them from
/// the upstream header before compiling.
319 template <typename GemmConfig_>
320 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
 // The layout handled by this specialization.
322  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
323 
 // Scalar type of B as declared by the config.
325  typedef typename GemmConfig_::ScalarB Scalar;
 // Scalar type of B as seen by the multiply-add functor.
327  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
328 
 // WMMA matrix type for operand B.
330  typedef WmmaMatrix<GemmOperand::kB,
332  MultiplyAddScalar,
333  typename GemmConfig_::InstructionShape>
334  WmmaMatrix;
335 
 // Traits describing how B is read from global memory.
337  typedef GemmGlobalTileTraits<
338  // That's B.
340  // B is column-major.
342  // The pointer is Scalar const (GemmConfig_::ScalarB), not necessarily float.
343  Scalar const,
344  // The tile is OutputTile::kH x OutputTile::kD.
345  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
346  // The threads are arranged as (kThreads / OutputTile::kD) x OutputTile::kD.
347  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
348  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
349  GemmConfig_::kScalarsPerLdgB>
350  GlobalTileTraits;
351 
 // Number of MultiplyAddScalar elements in 16 bytes; added to the last extent of Tile below.
353  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
 // Tile shape: kStages x OutputTile::kH x (OutputTile::kD + kSkew).
355  typedef Shape<GemmConfig_::kStages,
356  GemmConfig_::OutputTile::kH,
357  GemmConfig_::OutputTile::kD + kSkew>
358  Tile;
359 
 // Traits for storing B into the tile defined above.
361  typedef GemmSharedStoreTileAbTraits<
362  // The scalar type.
363  MultiplyAddScalar,
364  // The shared-memory tile defined above (skew included).
365  Tile,
366  // The thread arrangement is taken from GlobalTileTraits above.
367  typename GlobalTileTraits::Threads,
368  // The number of scalars per STS (STS.32 or STS.128, etc).
369  GemmConfig_::kScalarsPerStsB>
370  SharedStoreTileTraits;
371 
 // Product of the instruction's H extent and the number of warps along H.
373  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
 // Traits for loading B from the tile defined above.
375  typedef WmmaGemmSharedLoadTileBTraits<
376  // The layout of the matrix.
378  // The scalar type.
379  MultiplyAddScalar,
380  // The tile in shared memory.
381  Tile,
382  // The number of warps.
383  typename GemmConfig_::Warps,
384  // The strides between warps (scaled by Tile::kW, which includes the skew).
385  GemmConfig_::InstructionShape::kH * Tile::kW,
386  // The number of iterations to load the data.
387  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
388  // The stride between iterations.
389  Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
390  // The shape of the instruction.
391  typename GemmConfig_::InstructionShape>
392  SharedLoadTileTraits;
393 };
394 
396 
/// Assembles every component type needed by WmmaGemmTraits from the user-facing
/// template parameters.
///
/// It builds the WmmaGemmConfig, selects the A/B tile-traits helpers for the requested
/// layouts, and from those derives the global-load iterators/streams, the shared-memory
/// store and load iterators/streams, the multiply-add functor, the accumulator-clearing
/// functor, and the epilogue. Index_ is the index type handed to the iterators and the
/// epilogue traits.
///
/// NOTE(review): in this extract the TileStoreIterator and TileLoadIterator typedefs
/// are truncated - the embedded listing numbers skip 446-447, 461-462, 471-472, 475,
/// 482-483 and 486, and the opening '<' of each of those typedefs is never closed.
/// Restore the missing template arguments from the upstream header before compiling.
397 template <
399  MatrixLayout::Kind kLayoutA_,
401  MatrixLayout::Kind kLayoutB_,
403  typename OutputTile_,
405  typename ScalarC_,
407  typename Accumulator_,
409  typename EpilogueFunctor_,
411  typename AccumulatorsPerWarp_,
413  typename InstructionShape_,
415  int kScalarsPerLdgA_,
417  int kScalarsPerLdgB_,
419  typename Index_>
420 struct WmmaGemmTraitsHelper {
 // The GEMM configuration (EpilogueFunctor_ and Index_ are not part of it).
422  typedef WmmaGemmConfig<kLayoutA_,
423  kLayoutB_,
424  OutputTile_,
425  ScalarC_,
426  Accumulator_,
427  AccumulatorsPerWarp_,
428  InstructionShape_,
429  kScalarsPerLdgA_,
430  kScalarsPerLdgB_>
431  GemmConfig;
432 
 // Tile-traits helpers chosen by the layout of A and of B.
434  typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
436  typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
437 
 // A: iterator over global memory, element-wise Copy transformer, iterator that
 // stores into the shared tile, and the stream combining the three.
439  typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
440  GlobalLoadIteratorA;
442  typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
444  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
445  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
448  SharedStoreIteratorA;
450  typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
451  GlobalLoadStreamA;
452 
 // B: same four components as for A.
454  typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
455  GlobalLoadIteratorB;
456  // The default transformer for B.
457  typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
459  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
460  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
463  SharedStoreIteratorB;
465  typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
466  GlobalLoadStreamB;
467 
 // Iterators and streams that read A and B back from the shared tiles; each is
 // parameterized on the helper's SharedLoadTileTraits and WmmaMatrix type.
469  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
470  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
473  Index_,
474  typename GemmTileTraitsHelperA::WmmaMatrix,
476  SharedLoadIteratorA;
478  typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
480  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
481  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
484  Index_,
485  typename GemmTileTraitsHelperB::WmmaMatrix,
487  SharedLoadIteratorB;
489  typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
490 
 // The multiply-add functor comes from the config.
492  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
 // Functor to clear accumulators, keyed on the functor's ScalarC. Note that the typedef
 // name shadows the ClearAccumulators template inside this struct.
494  typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
495 
 // Epilogue: WMMA-specific traits helper -> simplified epilogue traits -> GemmEpilogue.
497  typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
499  typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
500  GemmEpilogueTraits;
502  typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
503 };
504 
506 
/// Computes the default accumulators-per-warp shape for a given output tile.
///
/// The nested Shape is ShapeMin<OutputTile_, DefaultShape_>::Shape, i.e. the
/// dimension-by-dimension minimum of the two shapes, so the result never exceeds
/// OutputTile_ in any dimension. DefaultShape_ defaults to Shape<64, 32, 64>.
507 template <typename OutputTile_, typename DefaultShape_ = Shape<64, 32, 64> >
508 struct WmmaGemmAccumulatorsPerWarp {
509  typedef typename ShapeMin<OutputTile_, DefaultShape_>::Shape Shape;
510 };
511 
513 
/// User-facing traits of a WMMA-based GEMM.
///
/// Declares no members of its own; it instantiates GemmTraits with the component types
/// produced by Helper_ (by default a WmmaGemmTraitsHelper built from the other
/// parameters), plus IdentityBlockSwizzle and Index_.
///
/// Template parameters and their defaults:
///   kLayoutA_, kLayoutB_   layouts of A and B (no default)
///   OutputTile_            Shape<64, 128, 128>
///   ScalarC_               float
///   EpilogueFunctor_       LinearScaling<ScalarC_>
///   Accumulator_           ScalarC_
///   AccumulatorsPerWarp_   WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape
///   InstructionShape_      Shape<16, 16, 16>
///   kScalarsPerLdgA_       8
///   kScalarsPerLdgB_       8
///   Index_                 int
///   Helper_                WmmaGemmTraitsHelper<...> over all of the above; supplying a
///                          different Helper_ replaces every component type at once.
514 template <
516  MatrixLayout::Kind kLayoutA_,
518  MatrixLayout::Kind kLayoutB_,
520  typename OutputTile_ = Shape<64, 128, 128>,
522  typename ScalarC_ = float,
524  typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
526  typename Accumulator_ = ScalarC_,
528  typename AccumulatorsPerWarp_ = typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
530  typename InstructionShape_ = Shape<16, 16, 16>,
532  int kScalarsPerLdgA_ = 8,
534  int kScalarsPerLdgB_ = 8,
536  typename Index_ = int,
538  typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
539  kLayoutB_,
540  OutputTile_,
541  ScalarC_,
542  Accumulator_,
543  EpilogueFunctor_,
544  AccumulatorsPerWarp_,
545  InstructionShape_,
546  kScalarsPerLdgA_,
547  kScalarsPerLdgB_,
548  Index_> >
549 struct WmmaGemmTraits : public GemmTraits<
550  // The config.
551  typename Helper_::GemmConfig,
552  // The stream to load A from global memory to shared memory.
553  typename Helper_::GlobalLoadStreamA,
554  // The stream to load B from global memory to shared memory.
555  typename Helper_::GlobalLoadStreamB,
556  // The stream to load A from shared memory.
557  typename Helper_::SharedLoadStreamA,
558  // The stream to load B from shared memory.
559  typename Helper_::SharedLoadStreamB,
560  // The epilogue.
561  typename Helper_::Epilogue,
562  // The block swizzle to reorganize the grid (fixed to the identity swizzle here).
563  IdentityBlockSwizzle,
564  // The index.
565  Index_,
566  // The tool used to clear accumulators.
567  typename Helper_::ClearAccumulators> {};
568 
570 
571 } // namespace gemm
572 } // namespace cutlass
573 
574 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_traits.h:93
Definition: load_store.h:42
Definition: convert.h:33
Defines iterators for efficiently loading and storing to global memory.
Defines structural properties of complete GEMM computation.
Defines structural properties of WMMA GEMM's epilogue phase.
Definition: tile_iterator.h:62
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
Defines iterators for efficiently loading and storing tiles to and from shared memory.
Definition: matrix_traits.h:36
Definition: tile_iterator.h:67
Definition: matrix_traits.h:43
Defines tile iterator traits for loading thread block-level tile from global memory.
Definition: matrix_traits.h:36
Kind
Definition: matrix_traits.h:36
Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
Definition: matrix_traits.h:43
Implements a software-pipelined efficient GEMM.
Defines structural properties of the GEMM epilogue.
Shape<(A_::kD< B_::kD ? A_::kD :B_::kD),(A_::kH< B_::kH ? A_::kH :B_::kH),(A_::kW< B_::kW ? A_::kW :B_::kW),(A_::kC< B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:148
Defines conversion operations among Fragments of different base type.