Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/wmma_matrix.h"
31 #ifdef CUTLASS_USE_WMMA_API
32 
33 #include "cutlass/convert.h"
34 #include "cutlass/gemm/gemm.h"
43 
44 namespace cutlass {
45 namespace gemm {
46 
48 
// WmmaGemmConfig: the GemmConfig used by the WMMA GEMM path. It forwards its
// template parameters to GemmConfig positionally and plugs in
// WmmaGemmMultiplyAdd as the multiply-add functor.
//
// NOTE(review): this is a Doxygen source-listing scrape. The numeric prefix on
// each line is the original header's line number. The gaps in that numbering
// (50, 52, ..., 116) are lines the scrape dropped, presumably the per-argument
// doc comments. Confirm against the real header.
//
// Positional arguments passed to GemmConfig, as far as this listing shows:
//  - ScalarC_ is passed twice (lines 78 and 80), presumably as both the C and
//    D scalar types.
//  - WmmaGemmMultiplyAdd receives MatrixLayout::kColumnMajor after ScalarB_
//    (line 88), presumably the layout of the accumulator tile.
//  - kScalarsPerLdgA_ and kScalarsPerLdgB_ are each passed twice and followed
//    by a literal 8. This presumably means LDG width, STS width and LDS width
//    for A and B (the tile-traits helpers read kScalarsPerLdgA and
//    kScalarsPerStsA from the resulting config).
//  - 16 / sizeof(T) is the number of T elements that fit in 16 bytes.
//  - The meaning of the trailing 1, false, true, false is not visible in this
//    listing.
49 template <
51  MatrixLayout::Kind kLayoutA_,
53  MatrixLayout::Kind kLayoutB_,
55  typename OutputTile_,
57  typename ScalarA_,
59  typename ScalarB_,
61  typename ScalarC_,
63  typename Accumulator_,
65  typename WarpGemmShape_,
67  typename InstructionShape_,
69  int kScalarsPerLdgA_,
71  int kScalarsPerLdgB_>
72 struct WmmaGemmConfig : public GemmConfig<
74  ScalarA_,
76  ScalarB_,
78  ScalarC_,
80  ScalarC_,
82  OutputTile_,
84  WmmaGemmMultiplyAdd<kLayoutA_,
85  ScalarA_,
86  kLayoutB_,
87  ScalarB_,
88  MatrixLayout::kColumnMajor,
89  Accumulator_,
90  WarpGemmShape_,
91  InstructionShape_>,
93  kScalarsPerLdgA_,
95  kScalarsPerLdgA_,
97  8,
99  kScalarsPerLdgB_,
101  kScalarsPerLdgB_,
103  8,
105  16 / sizeof(ScalarC_),
107  16 / sizeof(Accumulator_),
109  16 / sizeof(Accumulator_),
111  1,
113  false,
115  true,
117  false> {};
118 
120 
// Primary template of the A-operand tile-traits helper for WMMA GEMMs.
// It has no members. Only the specializations that follow provide them; they
// are selected by the layout of A and, for the sub-byte cases, by ScalarA_.
// An unsupported combination therefore fails to compile as soon as a member
// such as GlobalTileTraits is looked up.
121 template <enum MatrixLayout::Kind kLayout_,
122  typename GemmConfig_,
123  typename ScalarA_>
124 struct WmmaGemmTileTraitsHelperA {};
125 
127 
// WmmaGemmTileTraitsHelperA for a column-major A (generic element type).
// It reuses the global-memory traits of
// GemmTileTraitsHelperA<kColumnMajor> (Base::GlobalTileTraits). On top of
// that it defines the padded shared-memory Tile, the WmmaMatrix fragment type,
// and the shared-store and shared-load traits.
// ScalarA_ is not referenced in this specialization; it only matters for
// selecting the sub-byte specializations.
//
// NOTE(review): Doxygen listing scrape; gaps in the numeric prefixes are
// dropped lines. Besides doc comments, code lines 144 (inside the WmmaMatrix
// argument list) and 168 (after "The layout of the matrix.") are missing.
// Both are presumably layout arguments; confirm against the real header.
128 template <typename GemmConfig_, typename ScalarA_>
129 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, ScalarA_>
130  : public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
132  typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
133 
// Number of MultiplyAddScalar elements in 16 bytes. It pads the innermost
// dimension of the shared tile (presumably to reduce shared-memory bank
// conflicts).
135  static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kD x (OutputTile::kW + kSkew).
137  typedef Shape<GemmConfig_::kStages,
138  GemmConfig_::OutputTile::kD,
139  GemmConfig_::OutputTile::kW + kSkew>
140  Tile;
141 
// Fragment type for operand A (argument line 144 is missing from this listing).
143  typedef WmmaMatrix<GemmOperand::kA,
145  typename Base::MultiplyAddScalar,
146  typename GemmConfig_::InstructionShape>
147  WmmaMatrix;
148 
// Traits for storing the globally loaded fragments into the shared Tile,
// using the thread arrangement of the base global-tile traits.
150  typedef GemmSharedStoreTileAbTraits<
151  // The pointer.
152  typename Base::MultiplyAddScalar,
153  // The tile has size KxM in GEMM's terminology.
154  Tile,
155  // The threads are distributed as warps x 32 (the traits may reorganize).
156  typename Base::GlobalTileTraits::Threads,
157  // The number of scalars per STS (STS.32 or STS.128, etc).
158  GemmConfig_::kScalarsPerStsA>
159  SharedStoreTileTraits;
160 
// Extent along OutputTile::kW covered by all warps in one pass: instruction
// width times the number of warps along W. The load below iterates
// OutputTile::kW / kScalarsPerW times.
162  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
// Element count of one InstructionShape::kD-deep slab of the padded tile;
// used as the first component of the iteration stride below.
164  static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
166  typedef WmmaGemmSharedLoadTileATraits<
167  // The layout of the matrix.
// (argument line 168 is missing from this listing)
169  // The pointer.
170  typename Base::MultiplyAddScalar,
171  // The output tile size.
172  Tile,
173  // The number of warps.
174  typename GemmConfig_::Warps,
175  // The strides between warps.
176  GemmConfig_::InstructionShape::kW,
177  // The number of iterations to load the data.
178  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
179  // The stride between iterations.
180  Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
181  // The shape of the instruction.
182  typename GemmConfig_::InstructionShape>
183  SharedLoadTileTraits;
184 };
185 
187 
// WmmaGemmTileTraitsHelperA for a row-major A (generic element type).
// Unlike the column-major case, it defines its own global-memory traits. The
// global tile is OutputTile::kW rows of OutputTile::kD elements. The kThreads
// threads are arranged as (kThreads / OutputTile::kD) x OutputTile::kD.
// ScalarA_ is not referenced here; it only matters for selecting the
// sub-byte specializations.
//
// NOTE(review): Doxygen listing scrape; gaps in the numeric prefixes are
// dropped lines. The missing code lines are:
//  - 200: inside the WmmaMatrix arguments.
//  - 208 and 210: GemmGlobalTileTraits arguments, presumably the operand and
//    layout named by the neighbouring comments.
//  - 246: layout argument of the shared-load traits.
// Confirm against the real header.
// NOTE(review): the thread shape uses integer division, so kThreads is
// presumably required to be a multiple of OutputTile::kD.
188 template <typename GemmConfig_, typename ScalarA_>
189 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, ScalarA_> {
191  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
192 
// Element type in global memory, and the type consumed by the multiply-add.
194  typedef typename GemmConfig_::ScalarA Scalar;
196  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
197 
// Fragment type for operand A (argument line 200 is missing from this listing).
199  typedef WmmaMatrix<GemmOperand::kA,
201  MultiplyAddScalar,
202  typename GemmConfig_::InstructionShape>
203  WmmaMatrix;
204 
206  typedef GemmGlobalTileTraits<
207  // That's A.
// (argument line 208 is missing from this listing)
209  // A is row-major.
// (argument line 210 is missing from this listing)
211  // The pointer is Scalar const.
212  Scalar const,
213  // The tile has size KxM in GEMM's terminology.
214  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
215  // The threads are distributed as warps x 32 (the traits may reorganize).
216  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
217  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
218  GemmConfig_::kScalarsPerLdgA>
219  GlobalTileTraits;
220 
// Number of MultiplyAddScalar elements in 16 bytes, used as row padding
// (presumably to reduce shared-memory bank conflicts).
222  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kW x (OutputTile::kD + kSkew).
224  typedef Shape<GemmConfig_::kStages,
225  GemmConfig_::OutputTile::kW,
226  GemmConfig_::OutputTile::kD + kSkew>
227  Tile;
228 
230  typedef GemmSharedStoreTileAbTraits<
231  // The pointer.
232  MultiplyAddScalar,
233  // The tile has size KxM in GEMM's terminology.
234  Tile,
235  // The threads are distributed as warps x 32 (the traits may reorganize).
236  typename GlobalTileTraits::Threads,
237  // The number of scalars per STS (STS.32 or STS.128, etc).
238  GemmConfig_::kScalarsPerStsA>
239  SharedStoreTileTraits;
240 
// Rows along OutputTile::kW covered by all warps in one pass.
242  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
// Warps are InstructionShape::kW padded rows apart (rows are Tile::kW elements).
244  typedef WmmaGemmSharedLoadTileATraits<
245  // The layout of the matrix.
// (argument line 246 is missing from this listing)
247  // The pointer.
248  MultiplyAddScalar,
249  // The tile in shared memory.
250  Tile,
251  // The number of warps.
252  typename GemmConfig_::Warps,
253  // The strides between warps.
254  GemmConfig_::InstructionShape::kW * Tile::kW,
255  // The number of iterations to load the data.
256  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
257  // The stride between iterations.
258  Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
259  // The shape of the instruction.
260  typename GemmConfig_::InstructionShape>
261  SharedLoadTileTraits;
262 };
263 
265 
// WmmaGemmTileTraitsHelperA for a row-major A of 1-bit elements packed as
// Vector<bin1_t, 32>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Global and shared tiles are counted in storage scalars. Therefore every
// extent along OutputTile::kD / InstructionShape::kD is divided by
// kBitsPerScalar, and so are the LDG/STS widths.
//
// NOTE(review): Doxygen listing scrape; the missing code lines are 284
// (WmmaMatrix arguments), 292 and 294 (GemmGlobalTileTraits arguments) and
// 332 (shared-load layout). Confirm against the real header.
// NOTE(review): kScalarsPerLdgA / kBitsPerScalar and
// kScalarsPerStsA / kBitsPerScalar are integer divisions. If Scalar is 32 bits
// wide, the WmmaGemmTraits default kScalarsPerLdgA_ = 8 yields 0 here. Callers
// presumably must pass a multiple of kBitsPerScalar.
266 #ifdef CUTLASS_USE_SUBBYTE_WMMA
267 template <typename GemmConfig_>
269 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<bin1_t, 32> > {
271  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
272 
274  typedef typename GemmConfig_::ScalarA Scalar;
276  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
277 
// Number of 1-bit elements packed into one storage Scalar (8 per byte).
280  static int const kBitsPerScalar = sizeof(Scalar) * 8;
281 
// Fragment type for operand A (argument line 284 is missing from this listing).
283  typedef WmmaMatrix<GemmOperand::kA,
285  Vector<bin1_t, 32>,
286  typename GemmConfig_::InstructionShape>
287  WmmaMatrix;
288 
290  typedef GemmGlobalTileTraits<
291  // That's A.
// (argument line 292 is missing from this listing)
293  // A is row-major.
// (argument line 294 is missing from this listing)
295  // The pointer is Scalar const.
296  Scalar const,
297  // The tile has size KxM in GEMM's terminology.
298  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
299  // The threads are distributed as warps x 32 (the traits may reorganize).
300  Shape<1,
301  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
302  GemmConfig_::OutputTile::kD / kBitsPerScalar>,
303  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
304  GemmConfig_::kScalarsPerLdgA / kBitsPerScalar>
305  GlobalTileTraits;
306 
// Padding, in MultiplyAddScalar elements, equal to 16 bytes.
308  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kW x (packed row length + kSkew).
310  typedef Shape<GemmConfig_::kStages,
311  GemmConfig_::OutputTile::kW,
312  GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
313  Tile;
314 
316  typedef GemmSharedStoreTileAbTraits<
317  // The pointer.
318  MultiplyAddScalar,
319  // The tile has size KxM in GEMM's terminology.
320  Tile,
321  // The threads are distributed as warps x 32 (the traits may reorganize).
322  typename GlobalTileTraits::Threads,
323  // The number of scalars per STS (STS.32 or STS.128, etc).
324  GemmConfig_::kScalarsPerStsA / kBitsPerScalar>
325  SharedStoreTileTraits;
326 
// Rows along OutputTile::kW covered by all warps in one pass.
328  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
330  typedef WmmaGemmSharedLoadTileATraits<
331  // The layout of the matrix.
// (argument line 332 is missing from this listing)
333  // The pointer.
334  MultiplyAddScalar,
335  // The tile in shared memory.
336  Tile,
337  // The number of warps.
338  typename GemmConfig_::Warps,
339  // The strides between warps.
340  GemmConfig_::InstructionShape::kW * Tile::kW,
341  // The number of iterations to load the data.
342  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
343  // The stride between iterations.
344  Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
345  // The shape of the instruction.
346  typename GemmConfig_::InstructionShape>
347  SharedLoadTileTraits;
348 };
349 #endif
350 
352 
// WmmaGemmTileTraitsHelperA for a row-major A of unsigned 4-bit elements
// packed as Vector<uint4_t, 8>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA
// is defined. Global and shared tiles are counted in storage scalars.
// Therefore every extent along OutputTile::kD / InstructionShape::kD is
// divided by kInt4PerScalar, and so are the LDG/STS widths.
//
// NOTE(review): Doxygen listing scrape; the missing code lines are 371
// (WmmaMatrix arguments), 379 and 381 (GemmGlobalTileTraits arguments) and
// 419 (shared-load layout). Confirm against the real header.
// NOTE(review): kScalarsPerLdgA / kInt4PerScalar and
// kScalarsPerStsA / kInt4PerScalar are integer divisions. If Scalar is 4 bytes
// wide, the WmmaGemmTraits default kScalarsPerLdgA_ = 8 yields 1 element per
// LDG. Smaller values yield 0.
353 #ifdef CUTLASS_USE_SUBBYTE_WMMA
354 template <typename GemmConfig_>
356 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<uint4_t, 8> > {
358  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
359 
361  typedef typename GemmConfig_::ScalarA Scalar;
363  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
364 
// Number of 4-bit elements packed into one storage Scalar (two per byte).
367  static int const kInt4PerScalar = sizeof(Scalar) * 2;
368 
// Fragment type for operand A (argument line 371 is missing from this listing).
370  typedef WmmaMatrix<GemmOperand::kA,
372  Vector<uint4_t, 8>,
373  typename GemmConfig_::InstructionShape>
374  WmmaMatrix;
375 
377  typedef GemmGlobalTileTraits<
378  // That's A.
// (argument line 379 is missing from this listing)
380  // A is row-major.
// (argument line 381 is missing from this listing)
382  // The pointer is Scalar const.
383  Scalar const,
384  // The tile has size KxM in GEMM's terminology.
385  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
386  // The threads are distributed as warps x 32 (the traits may reorganize).
387  Shape<1,
388  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
389  GemmConfig_::OutputTile::kD / kInt4PerScalar>,
390  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
391  GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
392  GlobalTileTraits;
393 
// Padding, in MultiplyAddScalar elements, equal to 16 bytes.
395  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kW x (packed row length + kSkew).
397  typedef Shape<GemmConfig_::kStages,
398  GemmConfig_::OutputTile::kW,
399  GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
400  Tile;
401 
403  typedef GemmSharedStoreTileAbTraits<
404  // The pointer.
405  MultiplyAddScalar,
406  // The tile has size KxM in GEMM's terminology.
407  Tile,
408  // The threads are distributed as warps x 32 (the traits may reorganize).
409  typename GlobalTileTraits::Threads,
410  // The number of scalars per STS (STS.32 or STS.128, etc).
411  GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
412  SharedStoreTileTraits;
413 
// Rows along OutputTile::kW covered by all warps in one pass.
415  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
417  typedef WmmaGemmSharedLoadTileATraits<
418  // The layout of the matrix.
// (argument line 419 is missing from this listing)
420  // The pointer.
421  MultiplyAddScalar,
422  // The tile in shared memory.
423  Tile,
424  // The number of warps.
425  typename GemmConfig_::Warps,
426  // The strides between warps.
427  GemmConfig_::InstructionShape::kW * Tile::kW,
428  // The number of iterations to load the data.
429  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
430  // The stride between iterations.
431  Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
432  // The shape of the instruction.
433  typename GemmConfig_::InstructionShape>
434  SharedLoadTileTraits;
435 };
436 #endif
437 
439 
// WmmaGemmTileTraitsHelperA for a row-major A of signed 4-bit elements packed
// as Vector<int4_t, 8>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Global and shared tiles are counted in storage scalars. Therefore every
// extent along OutputTile::kD / InstructionShape::kD is divided by
// kInt4PerScalar, and so are the LDG/STS widths.
//
// NOTE(review): Doxygen listing scrape; the missing code lines are 458
// (WmmaMatrix arguments), 466 and 468 (GemmGlobalTileTraits arguments) and
// 506 (shared-load layout). Confirm against the real header.
// NOTE(review): kScalarsPerLdgA / kInt4PerScalar and
// kScalarsPerStsA / kInt4PerScalar are integer divisions. Values smaller than
// kInt4PerScalar yield 0.
440 #ifdef CUTLASS_USE_SUBBYTE_WMMA
441 template <typename GemmConfig_>
443 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<int4_t, 8> > {
445  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
446 
448  typedef typename GemmConfig_::ScalarA Scalar;
450  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
451 
// Number of 4-bit elements packed into one storage Scalar (two per byte).
454  static int const kInt4PerScalar = sizeof(Scalar) * 2;
455 
// Fragment type for operand A (argument line 458 is missing from this listing).
457  typedef WmmaMatrix<GemmOperand::kA,
459  Vector<int4_t, 8>,
460  typename GemmConfig_::InstructionShape>
461  WmmaMatrix;
462 
464  typedef GemmGlobalTileTraits<
465  // That's A.
// (argument line 466 is missing from this listing)
467  // A is row-major.
// (argument line 468 is missing from this listing)
469  // The pointer is Scalar const.
470  Scalar const,
471  // The tile has size KxM in GEMM's terminology.
472  Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
473  // The threads are distributed as warps x 32 (the traits may reorganize).
474  Shape<1,
475  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
476  GemmConfig_::OutputTile::kD / kInt4PerScalar>,
477  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
478  GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
479  GlobalTileTraits;
480 
// Padding, in MultiplyAddScalar elements, equal to 16 bytes.
482  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kW x (packed row length + kSkew).
484  typedef Shape<GemmConfig_::kStages,
485  GemmConfig_::OutputTile::kW,
486  GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
487  Tile;
488 
490  typedef GemmSharedStoreTileAbTraits<
491  // The pointer.
492  MultiplyAddScalar,
493  // The tile has size KxM in GEMM's terminology.
494  Tile,
495  // The threads are distributed as warps x 32 (the traits may reorganize).
496  typename GlobalTileTraits::Threads,
497  // The number of scalars per STS (STS.32 or STS.128, etc).
498  GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
499  SharedStoreTileTraits;
500 
// Rows along OutputTile::kW covered by all warps in one pass.
502  static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
504  typedef WmmaGemmSharedLoadTileATraits<
505  // The layout of the matrix.
// (argument line 506 is missing from this listing)
507  // The pointer.
508  MultiplyAddScalar,
509  // The tile in shared memory.
510  Tile,
511  // The number of warps.
512  typename GemmConfig_::Warps,
513  // The strides between warps.
514  GemmConfig_::InstructionShape::kW * Tile::kW,
515  // The number of iterations to load the data.
516  Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
517  // The stride between iterations.
518  Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
519  // The shape of the instruction.
520  typename GemmConfig_::InstructionShape>
521  SharedLoadTileTraits;
522 };
523 #endif
524 
526 
// Primary template of the B-operand tile-traits helper for WMMA GEMMs.
// It has no members. Only the specializations that follow provide them; they
// are selected by the layout of B and, for the sub-byte cases, by ScalarB_.
// An unsupported combination therefore fails to compile as soon as a member
// such as GlobalTileTraits is looked up.
527 template <enum MatrixLayout::Kind kLayout_,
528  typename GemmConfig_,
529  typename ScalarB_>
530 struct WmmaGemmTileTraitsHelperB {};
531 
533 
// WmmaGemmTileTraitsHelperB for a row-major B (generic element type).
// It mirrors the column-major A helper with OutputTile::kH in place of
// OutputTile::kW. Global-memory traits come from
// GemmTileTraitsHelperB<kRowMajor> (Base::GlobalTileTraits). On top of that it
// defines the padded shared Tile, the WmmaMatrix fragment type, and the
// shared-store and shared-load traits.
// ScalarB_ is not referenced here; it only matters for selecting the
// sub-byte specializations. Some inline comments below ("KxM") were copied
// from the A helper.
//
// NOTE(review): Doxygen listing scrape; the missing code lines are 550
// (inside the WmmaMatrix arguments) and 574 (after "The layout of the
// matrix."). Confirm against the real header.
534 template <typename GemmConfig_, typename ScalarB_>
535 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, ScalarB_>
536  : public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
538  typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
539 
// Number of MultiplyAddScalar elements in 16 bytes, used as padding of the
// innermost tile dimension (presumably to reduce shared-memory bank conflicts).
541  static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kD x (OutputTile::kH + kSkew).
543  typedef Shape<GemmConfig_::kStages,
544  GemmConfig_::OutputTile::kD,
545  GemmConfig_::OutputTile::kH + kSkew>
546  Tile;
547 
// Fragment type for operand B (argument line 550 is missing from this listing).
549  typedef WmmaMatrix<GemmOperand::kB,
551  typename Base::MultiplyAddScalar,
552  typename GemmConfig_::InstructionShape>
553  WmmaMatrix;
554 
556  typedef GemmSharedStoreTileAbTraits<
557  // The pointer.
558  typename Base::MultiplyAddScalar,
559  // The tile has size KxM in GEMM's terminology.
560  Tile,
561  // The threads are distributed as warps x 32 (the traits may reorganize).
562  typename Base::GlobalTileTraits::Threads,
563  // The number of scalars per STS (STS.32 or STS.128, etc).
564  GemmConfig_::kScalarsPerStsB>
565  SharedStoreTileTraits;
566 
// Extent along OutputTile::kH covered by all warps in one pass.
568  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
// Element count of one InstructionShape::kD-deep slab of the padded tile.
570  static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
572  typedef WmmaGemmSharedLoadTileBTraits<
573  // The layout of the matrix.
// (argument line 574 is missing from this listing)
575  // The pointer.
576  typename Base::MultiplyAddScalar,
577  // The output tile size.
578  Tile,
579  // The number of warps.
580  typename GemmConfig_::Warps,
581  // The strides between warps.
582  GemmConfig_::InstructionShape::kH,
583  // The number of iterations to load the data.
584  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
585  // The stride between iterations.
586  Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
587  // The shape of the instruction.
588  typename GemmConfig_::InstructionShape>
589  SharedLoadTileTraits;
590 };
591 
593 
// WmmaGemmTileTraitsHelperB for a column-major B (generic element type).
// It defines its own global-memory traits. The global tile is OutputTile::kH
// rows of OutputTile::kD elements. The kThreads threads are arranged as
// (kThreads / OutputTile::kD) x OutputTile::kD.
// ScalarB_ is not referenced here; it only matters for selecting the
// sub-byte specializations. Several inline comments below ("A is row-major",
// "KxM") were copied from the A helper.
//
// NOTE(review): Doxygen listing scrape; the missing code lines are:
//  - 606: WmmaMatrix arguments.
//  - 614 and 616: GemmGlobalTileTraits operand and layout arguments.
//  - 652: shared-load layout.
// Confirm against the real header.
594 template <typename GemmConfig_, typename ScalarB_>
595 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, ScalarB_> {
597  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
598 
// Element type in global memory, and the type consumed by the multiply-add.
600  typedef typename GemmConfig_::ScalarB Scalar;
602  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
603 
// Fragment type for operand B (argument line 606 is missing from this listing).
605  typedef WmmaMatrix<GemmOperand::kB,
607  MultiplyAddScalar,
608  typename GemmConfig_::InstructionShape>
609  WmmaMatrix;
610 
612  typedef GemmGlobalTileTraits<
613  // That's B.
// (argument line 614 is missing from this listing)
// NOTE(review): this is operand B; the next comment was copied from the A
// helper. The layout argument (line 616) is missing from this listing.
615  // A is row-major.
617  // The pointer is Scalar const.
618  Scalar const,
619  // The tile has size KxM in GEMM's terminology.
620  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
621  // The threads are distributed as warps x 32 (the traits may reorganize).
622  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
623  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
624  GemmConfig_::kScalarsPerLdgB>
625  GlobalTileTraits;
626 
// Number of MultiplyAddScalar elements in 16 bytes, used as row padding.
628  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kH x (OutputTile::kD + kSkew).
630  typedef Shape<GemmConfig_::kStages,
631  GemmConfig_::OutputTile::kH,
632  GemmConfig_::OutputTile::kD + kSkew>
633  Tile;
634 
636  typedef GemmSharedStoreTileAbTraits<
637  // The pointer.
638  MultiplyAddScalar,
639  // The tile has size KxM in GEMM's terminology.
640  Tile,
641  // The threads are distributed as warps x 32 (the traits may reorganize).
642  typename GlobalTileTraits::Threads,
643  // The number of scalars per STS (STS.32 or STS.128, etc).
644  GemmConfig_::kScalarsPerStsB>
645  SharedStoreTileTraits;
646 
// Rows along OutputTile::kH covered by all warps in one pass.
648  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
650  typedef WmmaGemmSharedLoadTileBTraits<
651  // The layout of the matrix.
// (argument line 652 is missing from this listing)
653  // The pointer.
654  MultiplyAddScalar,
655  // The tile in shared memory.
656  Tile,
657  // The number of warps.
658  typename GemmConfig_::Warps,
659  // The strides between warps.
660  GemmConfig_::InstructionShape::kH * Tile::kW,
661  // The number of iterations to load the data.
662  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
663  // The stride between iterations.
664  Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
665  // The shape of the instruction.
666  typename GemmConfig_::InstructionShape>
667  SharedLoadTileTraits;
668 };
669 
671 
// WmmaGemmTileTraitsHelperB for a column-major B of 1-bit elements packed as
// Vector<bin1_t, 32>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
// Global and shared tiles are counted in storage scalars. Therefore every
// extent along OutputTile::kD / InstructionShape::kD is divided by
// kBitsPerScalar, and so are the LDG/STS widths.
//
// NOTE(review): Doxygen listing scrape; the missing code lines are 690
// (WmmaMatrix arguments), 698 and 700 (GemmGlobalTileTraits arguments) and
// 738 (shared-load layout). Confirm against the real header.
// NOTE(review): kScalarsPerLdgB / kBitsPerScalar and
// kScalarsPerStsB / kBitsPerScalar are integer divisions. If Scalar is 32 bits
// wide, the WmmaGemmTraits default kScalarsPerLdgB_ = 8 yields 0 here.
672 #ifdef CUTLASS_USE_SUBBYTE_WMMA
673 template <typename GemmConfig_>
675 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<bin1_t, 32> > {
677  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
678 
680  typedef typename GemmConfig_::ScalarB Scalar;
682  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
683 
// Number of 1-bit elements packed into one storage Scalar (8 per byte).
686  static int const kBitsPerScalar = sizeof(Scalar) * 8;
687 
// Fragment type for operand B (argument line 690 is missing from this listing).
689  typedef WmmaMatrix<GemmOperand::kB,
691  Vector<bin1_t, 32>,
692  typename GemmConfig_::InstructionShape>
693  WmmaMatrix;
694 
696  typedef GemmGlobalTileTraits<
697  // That's B.
// (argument line 698 is missing from this listing)
// NOTE(review): this is operand B; the next comment was copied from the A
// helper. The layout argument (line 700) is missing from this listing.
699  // A is row-major.
701  // The pointer is Scalar const.
702  Scalar const,
703  // The tile has size KxM in GEMM's terminology.
704  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
705  // The threads are distributed as warps x 32 (the traits may reorganize).
706  Shape<1,
707  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
708  GemmConfig_::OutputTile::kD / kBitsPerScalar>,
709  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
710  GemmConfig_::kScalarsPerLdgB / kBitsPerScalar>
711  GlobalTileTraits;
712 
// Padding, in MultiplyAddScalar elements, equal to 16 bytes.
714  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kH x (packed row length + kSkew).
716  typedef Shape<GemmConfig_::kStages,
717  GemmConfig_::OutputTile::kH,
718  GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
719  Tile;
720 
722  typedef GemmSharedStoreTileAbTraits<
723  // The pointer.
724  MultiplyAddScalar,
725  // The tile has size KxM in GEMM's terminology.
726  Tile,
727  // The threads are distributed as warps x 32 (the traits may reorganize).
728  typename GlobalTileTraits::Threads,
729  // The number of scalars per STS (STS.32 or STS.128, etc).
730  GemmConfig_::kScalarsPerStsB / kBitsPerScalar>
731  SharedStoreTileTraits;
732 
// Rows along OutputTile::kH covered by all warps in one pass.
734  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
736  typedef WmmaGemmSharedLoadTileBTraits<
737  // The layout of the matrix.
// (argument line 738 is missing from this listing)
739  // The pointer.
740  MultiplyAddScalar,
741  // The tile in shared memory.
742  Tile,
743  // The number of warps.
744  typename GemmConfig_::Warps,
745  // The strides between warps.
746  GemmConfig_::InstructionShape::kH * Tile::kW,
747  // The number of iterations to load the data.
748  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
749  // The stride between iterations.
750  Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
751  // The shape of the instruction.
752  typename GemmConfig_::InstructionShape>
753  SharedLoadTileTraits;
754 };
755 #endif
756 
758 
// WmmaGemmTileTraitsHelperB for a column-major B of unsigned 4-bit elements
// packed as Vector<uint4_t, 8>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA
// is defined. Global and shared tiles are counted in storage scalars.
// Therefore every extent along OutputTile::kD / InstructionShape::kD is
// divided by kInt4PerScalar, and so are the LDG/STS widths.
//
// NOTE(review): Doxygen listing scrape; the missing code lines are 777
// (WmmaMatrix arguments), 785 and 787 (GemmGlobalTileTraits arguments) and
// 825 (shared-load layout). Confirm against the real header.
// NOTE(review): kScalarsPerLdgB / kInt4PerScalar and
// kScalarsPerStsB / kInt4PerScalar are integer divisions. Values smaller than
// kInt4PerScalar yield 0.
759 #ifdef CUTLASS_USE_SUBBYTE_WMMA
760 template <typename GemmConfig_>
762 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<uint4_t, 8> > {
764  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
765 
767  typedef typename GemmConfig_::ScalarB Scalar;
769  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
770 
// Number of 4-bit elements packed into one storage Scalar (two per byte).
773  static int const kInt4PerScalar = sizeof(Scalar) * 2;
774 
// Fragment type for operand B (argument line 777 is missing from this listing).
776  typedef WmmaMatrix<GemmOperand::kB,
778  Vector<uint4_t, 8>,
779  typename GemmConfig_::InstructionShape>
780  WmmaMatrix;
781 
783  typedef GemmGlobalTileTraits<
784  // That's B.
// (argument line 785 is missing from this listing)
// NOTE(review): this is operand B; the next comment was copied from the A
// helper. The layout argument (line 787) is missing from this listing.
786  // A is row-major.
788  // The pointer is Scalar const.
789  Scalar const,
790  // The tile has size KxM in GEMM's terminology.
791  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
792  // The threads are distributed as warps x 32 (the traits may reorganize).
793  Shape<1,
794  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
795  GemmConfig_::OutputTile::kD / kInt4PerScalar>,
796  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
797  GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
798  GlobalTileTraits;
799 
// Padding, in MultiplyAddScalar elements, equal to 16 bytes.
801  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kH x (packed row length + kSkew).
803  typedef Shape<GemmConfig_::kStages,
804  GemmConfig_::OutputTile::kH,
805  GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
806  Tile;
807 
809  typedef GemmSharedStoreTileAbTraits<
810  // The pointer.
811  MultiplyAddScalar,
812  // The tile has size KxM in GEMM's terminology.
813  Tile,
814  // The threads are distributed as warps x 32 (the traits may reorganize).
815  typename GlobalTileTraits::Threads,
816  // The number of scalars per STS (STS.32 or STS.128, etc).
817  GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
818  SharedStoreTileTraits;
819 
// Rows along OutputTile::kH covered by all warps in one pass.
821  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
823  typedef WmmaGemmSharedLoadTileBTraits<
824  // The layout of the matrix.
// (argument line 825 is missing from this listing)
826  // The pointer.
827  MultiplyAddScalar,
828  // The tile in shared memory.
829  Tile,
830  // The number of warps.
831  typename GemmConfig_::Warps,
832  // The strides between warps.
833  GemmConfig_::InstructionShape::kH * Tile::kW,
834  // The number of iterations to load the data.
835  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
836  // The stride between iterations.
837  Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
838  // The shape of the instruction.
839  typename GemmConfig_::InstructionShape>
840  SharedLoadTileTraits;
841 };
842 #endif
843 
845 
// WmmaGemmTileTraitsHelperB for a column-major B of signed 4-bit elements
// packed as Vector<int4_t, 8>. Only compiled when CUTLASS_USE_SUBBYTE_WMMA is
// defined. Global and shared tiles are counted in storage scalars. Therefore
// every extent along OutputTile::kD / InstructionShape::kD is divided by
// kInt4PerScalar, and so are the LDG/STS widths.
//
// NOTE(review): Doxygen listing scrape; the missing code lines are 864
// (WmmaMatrix arguments), 872 and 874 (GemmGlobalTileTraits arguments) and
// 912 (shared-load layout). Confirm against the real header.
// NOTE(review): kScalarsPerLdgB / kInt4PerScalar and
// kScalarsPerStsB / kInt4PerScalar are integer divisions. Values smaller than
// kInt4PerScalar yield 0.
846 #ifdef CUTLASS_USE_SUBBYTE_WMMA
847 template <typename GemmConfig_>
849 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<int4_t, 8> > {
851  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
852 
854  typedef typename GemmConfig_::ScalarB Scalar;
856  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
857 
// Number of 4-bit elements packed into one storage Scalar (two per byte).
860  static int const kInt4PerScalar = sizeof(Scalar) * 2;
861 
// Fragment type for operand B (argument line 864 is missing from this listing).
863  typedef WmmaMatrix<GemmOperand::kB,
865  Vector<int4_t, 8>,
866  typename GemmConfig_::InstructionShape>
867  WmmaMatrix;
868 
870  typedef GemmGlobalTileTraits<
871  // That's B.
// (argument line 872 is missing from this listing)
// NOTE(review): this is operand B; the next comment was copied from the A
// helper. The layout argument (line 874) is missing from this listing.
873  // A is row-major.
875  // The pointer is Scalar const.
876  Scalar const,
877  // The tile has size KxM in GEMM's terminology.
878  Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
879  // The threads are distributed as warps x 32 (the traits may reorganize).
880  Shape<1,
881  GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
882  GemmConfig_::OutputTile::kD / kInt4PerScalar>,
883  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
884  GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
885  GlobalTileTraits;
886 
// Padding, in MultiplyAddScalar elements, equal to 16 bytes.
888  static int const kSkew = 16 / sizeof(MultiplyAddScalar);
// Shared tile: kStages x OutputTile::kH x (packed row length + kSkew).
890  typedef Shape<GemmConfig_::kStages,
891  GemmConfig_::OutputTile::kH,
892  GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
893  Tile;
894 
896  typedef GemmSharedStoreTileAbTraits<
897  // The pointer.
898  MultiplyAddScalar,
899  // The tile has size KxM in GEMM's terminology.
900  Tile,
901  // The threads are distributed as warps x 32 (the traits may reorganize).
902  typename GlobalTileTraits::Threads,
903  // The number of scalars per STS (STS.32 or STS.128, etc).
904  GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
905  SharedStoreTileTraits;
906 
// Rows along OutputTile::kH covered by all warps in one pass.
908  static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
910  typedef WmmaGemmSharedLoadTileBTraits<
911  // The layout of the matrix.
// (argument line 912 is missing from this listing)
913  // The pointer.
914  MultiplyAddScalar,
915  // The tile in shared memory.
916  Tile,
917  // The number of warps.
918  typename GemmConfig_::Warps,
919  // The strides between warps.
920  GemmConfig_::InstructionShape::kH * Tile::kW,
921  // The number of iterations to load the data.
922  Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
923  // The stride between iterations.
924  Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
925  // The shape of the instruction.
926  typename GemmConfig_::InstructionShape>
927  SharedLoadTileTraits;
928 };
929 #endif
930 
932 
// WmmaGemmTraitsHelper: assembles every component type that WmmaGemmTraits
// forwards to GemmTraits:
//  - GemmConfig: WmmaGemmConfig built from the same parameters.
//  - Tile-traits helpers for A and B, chosen by layout and scalar type.
//  - Global -> shared load streams for A and B. Each uses a
//    GemmGlobalIteratorAb, a Copy transformer and a TileStoreIterator.
//  - Shared-memory load streams for A and B. Each uses a TileLoadIterator
//    parameterized by the helper's WmmaMatrix fragment type.
//  - MultiplyAdd, ClearAccumulators and the WMMA epilogue.
//
// NOTE(review): Doxygen listing scrape. Template-argument lines 988-989,
// 1006-1007, 1019-1020, 1023, 1030-1031 and 1034 are missing from the
// TileStoreIterator / TileLoadIterator argument lists. Confirm against the
// real header.
933 template <
935  MatrixLayout::Kind kLayoutA_,
937  MatrixLayout::Kind kLayoutB_,
939  typename OutputTile_,
941  typename ScalarA_,
943  typename ScalarB_,
945  typename ScalarC_,
947  typename Accumulator_,
949  typename EpilogueFunctor_,
951  typename WarpGemmShape_,
953  typename InstructionShape_,
955  int kScalarsPerLdgA_,
957  int kScalarsPerLdgB_,
959  typename Index_>
960 struct WmmaGemmTraitsHelper {
962  typedef WmmaGemmConfig<kLayoutA_,
963  kLayoutB_,
964  OutputTile_,
965  ScalarA_,
966  ScalarB_,
967  ScalarC_,
968  Accumulator_,
969  WarpGemmShape_,
970  InstructionShape_,
971  kScalarsPerLdgA_,
972  kScalarsPerLdgB_>
973  GemmConfig;
974 
// The scalar type is passed so that the sub-byte specializations
// (Vector<bin1_t, 32>, Vector<uint4_t, 8>, Vector<int4_t, 8>) are selected.
976  typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig, ScalarA_> GemmTileTraitsHelperA;
978  typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig, ScalarB_> GemmTileTraitsHelperB;
979 
// Operand A: global iterator -> Copy transformer -> shared-store iterator.
981  typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
982  GlobalLoadIteratorA;
984  typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
986  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
987  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
// (argument lines 988-989 are missing from this listing)
990  SharedStoreIteratorA;
992  typedef GlobalLoadStream<GemmOperand::kA,
993  GlobalLoadIteratorA,
994  SharedStoreIteratorA,
995  GlobalTransformerA>
996  GlobalLoadStreamA;
997 
// Operand B: same pipeline as A.
999  typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
1000  GlobalLoadIteratorB;
1001  // The default transformer for B.
1002  typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
1004  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
1005  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
// (argument lines 1006-1007 are missing from this listing)
1008  SharedStoreIteratorB;
1010  typedef GlobalLoadStream<GemmOperand::kB,
1011  GlobalLoadIteratorB,
1012  SharedStoreIteratorB,
1013  GlobalTransformerB>
1014  GlobalLoadStreamB;
1015 
// Shared-memory load iterators yield the helpers' WmmaMatrix fragments.
1017  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
1018  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
// (argument lines 1019-1020 are missing from this listing)
1021  Index_,
1022  typename GemmTileTraitsHelperA::WmmaMatrix,
// (argument line 1023 is missing from this listing)
1024  SharedLoadIteratorA;
1026  typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
1028  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
1029  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
// (argument lines 1030-1031 are missing from this listing)
1032  Index_,
1033  typename GemmTileTraitsHelperB::WmmaMatrix,
// (argument line 1034 is missing from this listing)
1035  SharedLoadIteratorB;
1037  typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
1038 
1040  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
// The member typedef reuses the name of the ClearAccumulators template.
// NOTE(review): a name whose meaning changes within a class scope is
// ill-formed, no diagnostic required ([basic.scope.class]); some compilers
// may flag this.
1042  typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
1043 
// Epilogue: WMMA-specific helper wrapped in SimplifiedGemmEpilogueTraits.
1045  typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
1047  typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
1048  GemmEpilogueTraits;
1050  typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
1051 };
1052 
1054 
// Default warp-level tile shape for WMMA GEMMs. It is ShapeMin of the block's
// OutputTile_ and DefaultShape_ (Shape<64, 32, 64> unless overridden), i.e.
// the per-dimension minimum of the two. A small OutputTile_ therefore never
// gets a warp tile larger than itself.
1055 template <typename OutputTile_, typename DefaultShape_ = Shape<64, 32, 64> >
1056 struct WmmaGemmAccumulatorsPerWarp {
1057  typedef typename ShapeMin<OutputTile_, DefaultShape_>::Shape Shape;
1058 };
1059 
1061 
// WmmaGemmTraits: the user-facing traits type for a WMMA GEMM.
// Only the layouts of A and B are required. The defaults are:
//  - OutputTile Shape<64, 128, 128>.
//  - half A/B, float C, and Accumulator = ScalarC_.
//  - LinearScaling<ScalarC_> epilogue.
//  - Warp tile from WmmaGemmAccumulatorsPerWarp.
//  - 16x16x16 instruction shape.
//  - 8 scalars per global load for A and B.
//  - int indices.
// All component types come from Helper_ (WmmaGemmTraitsHelper by default).
// They are forwarded to GemmTraits together with IdentityBlockSwizzle.
// NOTE(review): the sub-byte tile-traits helpers divide kScalarsPerLdgA_ /
// kScalarsPerLdgB_ by the number of packed elements per storage scalar. The
// default of 8 is presumably too small there; pass larger values for
// sub-byte operands.
1062 template <
1064  MatrixLayout::Kind kLayoutA_,
1066  MatrixLayout::Kind kLayoutB_,
1068  typename OutputTile_ = Shape<64, 128, 128>,
1070  typename ScalarA_ = half,
1072  typename ScalarB_ = half,
1074  typename ScalarC_ = float,
1076  typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
1078  typename Accumulator_ = ScalarC_,
1080  typename WarpGemmShape_ = typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
1082  typename InstructionShape_ = Shape<16, 16, 16>,
1084  int kScalarsPerLdgA_ = 8,
1086  int kScalarsPerLdgB_ = 8,
1088  typename Index_ = int,
1090  typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
1091  kLayoutB_,
1092  OutputTile_,
1093  ScalarA_,
1094  ScalarB_,
1095  ScalarC_,
1096  Accumulator_,
1097  EpilogueFunctor_,
1098  WarpGemmShape_,
1099  InstructionShape_,
1100  kScalarsPerLdgA_,
1101  kScalarsPerLdgB_,
1102  Index_> >
1103 struct WmmaGemmTraits : public GemmTraits<
1104  // The config.
1105  typename Helper_::GemmConfig,
1106  // The stream to load A from global memory to shared memory.
1107  typename Helper_::GlobalLoadStreamA,
1108  // The stream to load B from global memory to shared memory.
1109  typename Helper_::GlobalLoadStreamB,
1110  // The stream to load A from shared memory.
1111  typename Helper_::SharedLoadStreamA,
1112  // The stream to load B from shared memory.
1113  typename Helper_::SharedLoadStreamB,
1114  // The epilogue.
1115  typename Helper_::Epilogue,
1116  // The block swizzle to reorganize the grid.
1117  IdentityBlockSwizzle,
1118  // The index.
1119  Index_,
1120  // The tool used to clear accumulators.
1121  typename Helper_::ClearAccumulators> {};
1122 
1124 
1125 } // namespace gemm
1126 } // namespace cutlass
1127 
1128 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Definition: convert.h:33
Defines iterators for efficiently loading and storing to global memory.
Defines structural properties of complete GEMM computation.
Defines structural properties of WMMA GEMM's epilogue phase.
Definition: tile_iterator.h:65
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
Defines iterators for efficiently loading and storing tiles to and from shared memory.
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_config.h:90
Definition: matrix_traits.h:159
Definition: matrix_traits.h:357
Defines tile iterator traits for loading thread block-level tile from global memory.
Definition: matrix_traits.h:159
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
Definition: matrix_traits.h:357
Implements a software-pipelined efficient GEMM.
Defines structural properties of the GEMM epilogue.
Shape<(A_::kD < B_::kD ? A_::kD : B_::kD), (A_::kH < B_::kH ? A_::kH : B_::kH), (A_::kW < B_::kW ? A_::kW : B_::kW), (A_::kC < B_::kC ? A_::kC : B_::kC)> Shape
Definition: shape.h:159
Defines conversion operations among Fragments of different base type.