/// NOTE(review): this listing appears to be a Doxygen source-browser extraction.
/// The leading numbers look like original header line numbers, and many lines
/// (template headers, typedef names, closing braces, #endif directives) are
/// missing. Consult the real header before relying on these fragments.
///
/// WmmaGemmConfig: a GemmConfig whose multiply-add functor is WmmaGemmMultiplyAdd.
/// Most of its template arguments are not visible here.
31 #ifdef CUTLASS_USE_WMMA_API 63 typename Accumulator_,
65 typename WarpGemmShape_,
67 typename InstructionShape_,
72 struct WmmaGemmConfig :
public GemmConfig<
84 WmmaGemmMultiplyAdd<kLayoutA_,
88 MatrixLayout::kColumnMajor,
// 16 / sizeof(T) is the number of T values that fit in 16 bytes. Presumably
// these are per-access vector widths for C and the accumulators; TODO confirm
// against the GemmConfig template parameter list.
105 16 / sizeof(ScalarC_),
107 16 / sizeof(Accumulator_),
109 16 / sizeof(Accumulator_),
/// Primary template of the A-operand tile-traits helper. It is intentionally
/// empty, so only the (layout, scalar) combinations specialized below provide
/// the traits WmmaGemmTraitsHelper reads.
// NOTE(review): the `template <` line and the leading layout parameter are not
// visible in this listing.
122 typename GemmConfig_,
124 struct WmmaGemmTileTraitsHelperA {};
/// A-operand helper for column-major A. It inherits the generic
/// GemmTileTraitsHelperA<kColumnMajor> (as Base), including Base::GlobalTileTraits,
/// and adds the WMMA-specific shared store/load tile traits.
128 template <
typename GemmConfig_,
typename ScalarA_>
129 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, ScalarA_>
130 :
public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
132 typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
// Padding of 16 bytes, expressed in MultiplyAddScalar elements. It is added to
// the kW extent of the tile shape below. Presumably this staggers rows to avoid
// shared-memory bank conflicts; confirm.
135 static int const kSkew = 16 /
sizeof(
typename Base::MultiplyAddScalar);
// Tile shape (referred to below as Tile): (kStages, OutputTile::kD, OutputTile::kW + kSkew).
137 typedef Shape<GemmConfig_::kStages,
138 GemmConfig_::OutputTile::kD,
139 GemmConfig_::OutputTile::kW + kSkew>
// Presumably the tail of the WmmaMatrix typedef (its name line is not visible).
// WmmaGemmTraitsHelper reads GemmTileTraitsHelperA::WmmaMatrix.
145 typename Base::MultiplyAddScalar,
146 typename GemmConfig_::InstructionShape>
150 typedef GemmSharedStoreTileAbTraits<
152 typename Base::MultiplyAddScalar,
156 typename Base::GlobalTileTraits::Threads,
158 GemmConfig_::kScalarsPerStsA>
159 SharedStoreTileTraits;
// InstructionShape::kW * Warps::kW: the scalars spanned along kW by one
// instruction replicated over the warps in that dimension.
162 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
164 static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
166 typedef WmmaGemmSharedLoadTileATraits<
170 typename Base::MultiplyAddScalar,
174 typename GemmConfig_::Warps,
176 GemmConfig_::InstructionShape::kW,
// OutputTile::kW / kScalarsPerW along W. Presumably the iteration count; confirm
// against WmmaGemmSharedLoadTileATraits.
178 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
// Presumably the strides between iterations; confirm against WmmaGemmSharedLoadTileATraits.
180 Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
182 typename GemmConfig_::InstructionShape>
183 SharedLoadTileTraits;
/// A-operand helper for row-major A. Unlike the column-major specialization it
/// has no base class. It defines its own Scalar, MultiplyAddScalar,
/// GlobalTileTraits, and shared store/load tile traits.
188 template <
typename GemmConfig_,
typename ScalarA_>
189 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, ScalarA_> {
194 typedef typename GemmConfig_::ScalarA Scalar;
196 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
202 typename GemmConfig_::InstructionShape>
// GlobalTileTraits: tile (1, OutputTile::kW, OutputTile::kD), read with
// kScalarsPerLdgA scalars per load.
206 typedef GemmGlobalTileTraits<
214 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
// Presumably the thread layout (read back below as GlobalTileTraits::Threads):
// (1, kThreads / OutputTile::kD, OutputTile::kD).
216 Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
218 GemmConfig_::kScalarsPerLdgA>
// 16 bytes of padding in MultiplyAddScalar elements. Here it is added to the
// kD extent; the column-major specialization pads kW instead.
222 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
// Tile shape (referred to below as Tile): (kStages, OutputTile::kW, OutputTile::kD + kSkew).
224 typedef Shape<GemmConfig_::kStages,
225 GemmConfig_::OutputTile::kW,
226 GemmConfig_::OutputTile::kD + kSkew>
230 typedef GemmSharedStoreTileAbTraits<
236 typename GlobalTileTraits::Threads,
238 GemmConfig_::kScalarsPerStsA>
239 SharedStoreTileTraits;
// Scalars spanned along kW by one instruction replicated over Warps::kW warps.
242 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
244 typedef WmmaGemmSharedLoadTileATraits<
252 typename GemmConfig_::Warps,
254 GemmConfig_::InstructionShape::kW * Tile::kW,
256 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
258 Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
260 typename GemmConfig_::InstructionShape>
261 SharedLoadTileTraits;
/// Row-major A specialization for Scalar = Vector<bin1_t, 32>.
/// It mirrors the generic row-major helper, but divides two sets of quantities
/// by kBitsPerScalar:
///   - every kD extent (global tile, thread layout, shared tile, load stride);
///   - the per-access scalar counts (kScalarsPerLdgA, kScalarsPerStsA).
// NOTE(review): the matching #endif for CUTLASS_USE_SUBBYTE_WMMA is not visible
// in this listing.
266 #ifdef CUTLASS_USE_SUBBYTE_WMMA 267 template <
typename GemmConfig_>
269 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<bin1_t, 32> > {
274 typedef typename GemmConfig_::ScalarA Scalar;
276 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
// Number of bits in one Scalar. Assumes Vector<bin1_t, 32> is bit-packed, so this
// equals its element count (32); TODO confirm.
280 static int const kBitsPerScalar =
sizeof(Scalar) * 8;
286 typename GemmConfig_::InstructionShape>
290 typedef GemmGlobalTileTraits<
298 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
301 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
302 GemmConfig_::OutputTile::kD / kBitsPerScalar>,
304 GemmConfig_::kScalarsPerLdgA / kBitsPerScalar>
// NOTE(review): kSkew is counted in MultiplyAddScalar units but is added to a kD
// extent counted in Scalar units. This is consistent only if both name the same
// packed Vector type; confirm.
308 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
310 typedef Shape<GemmConfig_::kStages,
311 GemmConfig_::OutputTile::kW,
312 GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
316 typedef GemmSharedStoreTileAbTraits<
322 typename GlobalTileTraits::Threads,
324 GemmConfig_::kScalarsPerStsA / kBitsPerScalar>
325 SharedStoreTileTraits;
328 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
330 typedef WmmaGemmSharedLoadTileATraits<
338 typename GemmConfig_::Warps,
340 GemmConfig_::InstructionShape::kW * Tile::kW,
342 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
344 Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
346 typename GemmConfig_::InstructionShape>
347 SharedLoadTileTraits;
/// Row-major A specialization for Scalar = Vector<uint4_t, 8>.
/// It mirrors the generic row-major helper, but divides two sets of quantities
/// by kInt4PerScalar:
///   - every kD extent;
///   - the per-access scalar counts (kScalarsPerLdgA, kScalarsPerStsA).
// NOTE(review): the matching #endif for CUTLASS_USE_SUBBYTE_WMMA is not visible
// in this listing.
353 #ifdef CUTLASS_USE_SUBBYTE_WMMA 354 template <
typename GemmConfig_>
356 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<uint4_t, 8> > {
361 typedef typename GemmConfig_::ScalarA Scalar;
363 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
// Two 4-bit values per byte of Scalar. Assumes Vector<uint4_t, 8> is packed into
// 4 bytes, so this equals its element count (8); TODO confirm.
367 static int const kInt4PerScalar =
sizeof(Scalar) * 2;
373 typename GemmConfig_::InstructionShape>
377 typedef GemmGlobalTileTraits<
385 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
388 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
389 GemmConfig_::OutputTile::kD / kInt4PerScalar>,
391 GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
// NOTE(review): kSkew is counted in MultiplyAddScalar units but is added to a kD
// extent counted in Scalar units; confirm both are the same packed type.
395 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
397 typedef Shape<GemmConfig_::kStages,
398 GemmConfig_::OutputTile::kW,
399 GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
403 typedef GemmSharedStoreTileAbTraits<
409 typename GlobalTileTraits::Threads,
411 GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
412 SharedStoreTileTraits;
415 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
417 typedef WmmaGemmSharedLoadTileATraits<
425 typename GemmConfig_::Warps,
427 GemmConfig_::InstructionShape::kW * Tile::kW,
429 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
431 Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
433 typename GemmConfig_::InstructionShape>
434 SharedLoadTileTraits;
/// Row-major A specialization for Scalar = Vector<int4_t, 8>. The visible body
/// matches the Vector<uint4_t, 8> specialization except for the scalar type.
/// Every kD extent and the per-access scalar counts are divided by kInt4PerScalar.
// NOTE(review): the matching #endif for CUTLASS_USE_SUBBYTE_WMMA is not visible
// in this listing.
440 #ifdef CUTLASS_USE_SUBBYTE_WMMA 441 template <
typename GemmConfig_>
443 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<int4_t, 8> > {
448 typedef typename GemmConfig_::ScalarA Scalar;
450 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
// Two 4-bit values per byte of Scalar. Assumes Vector<int4_t, 8> is packed into
// 4 bytes, so this equals its element count (8); TODO confirm.
454 static int const kInt4PerScalar =
sizeof(Scalar) * 2;
460 typename GemmConfig_::InstructionShape>
464 typedef GemmGlobalTileTraits<
472 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
475 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
476 GemmConfig_::OutputTile::kD / kInt4PerScalar>,
478 GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
// NOTE(review): kSkew is counted in MultiplyAddScalar units but is added to a kD
// extent counted in Scalar units; confirm both are the same packed type.
482 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
484 typedef Shape<GemmConfig_::kStages,
485 GemmConfig_::OutputTile::kW,
486 GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
490 typedef GemmSharedStoreTileAbTraits<
496 typename GlobalTileTraits::Threads,
498 GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
499 SharedStoreTileTraits;
502 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
504 typedef WmmaGemmSharedLoadTileATraits<
512 typename GemmConfig_::Warps,
514 GemmConfig_::InstructionShape::kW * Tile::kW,
516 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
518 Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
520 typename GemmConfig_::InstructionShape>
521 SharedLoadTileTraits;
/// Primary template of the B-operand tile-traits helper. It is intentionally
/// empty, so only the (layout, scalar) combinations specialized below provide
/// the traits WmmaGemmTraitsHelper reads.
// NOTE(review): the `template <` line and the leading layout parameter are not
// visible in this listing.
528 typename GemmConfig_,
530 struct WmmaGemmTileTraitsHelperB {};
/// B-operand helper for row-major B. It inherits the generic
/// GemmTileTraitsHelperB<kRowMajor> (as Base), including Base::GlobalTileTraits,
/// and adds the WMMA-specific shared store/load tile traits. It parallels the
/// column-major A helper with kH in place of kW.
534 template <
typename GemmConfig_,
typename ScalarB_>
535 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, ScalarB_>
536 :
public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
538 typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
// 16 bytes of padding in MultiplyAddScalar elements, added to the kH extent below.
// Presumably this avoids shared-memory bank conflicts; confirm.
541 static int const kSkew = 16 /
sizeof(
typename Base::MultiplyAddScalar);
// Tile shape (referred to below as Tile): (kStages, OutputTile::kD, OutputTile::kH + kSkew).
543 typedef Shape<GemmConfig_::kStages,
544 GemmConfig_::OutputTile::kD,
545 GemmConfig_::OutputTile::kH + kSkew>
// Presumably the tail of the WmmaMatrix typedef (its name line is not visible).
551 typename Base::MultiplyAddScalar,
552 typename GemmConfig_::InstructionShape>
556 typedef GemmSharedStoreTileAbTraits<
558 typename Base::MultiplyAddScalar,
562 typename Base::GlobalTileTraits::Threads,
564 GemmConfig_::kScalarsPerStsB>
565 SharedStoreTileTraits;
// Scalars spanned along kH by one instruction replicated over Warps::kH warps.
568 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
570 static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
572 typedef WmmaGemmSharedLoadTileBTraits<
576 typename Base::MultiplyAddScalar,
580 typename GemmConfig_::Warps,
582 GemmConfig_::InstructionShape::kH,
584 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
586 Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
588 typename GemmConfig_::InstructionShape>
589 SharedLoadTileTraits;
/// B-operand helper for column-major B. It has no base class and defines its own
/// Scalar, MultiplyAddScalar, GlobalTileTraits, and shared store/load tile traits.
/// It parallels the row-major A helper with kH in place of kW.
594 template <
typename GemmConfig_,
typename ScalarB_>
595 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, ScalarB_> {
600 typedef typename GemmConfig_::ScalarB Scalar;
602 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
608 typename GemmConfig_::InstructionShape>
// GlobalTileTraits: tile (1, OutputTile::kH, OutputTile::kD), read with
// kScalarsPerLdgB scalars per load. The second Shape is presumably the thread
// layout (read back below as GlobalTileTraits::Threads).
612 typedef GemmGlobalTileTraits<
620 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
622 Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
624 GemmConfig_::kScalarsPerLdgB>
// 16 bytes of padding in MultiplyAddScalar elements, added to the kD extent.
628 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
// Tile shape (referred to below as Tile): (kStages, OutputTile::kH, OutputTile::kD + kSkew).
630 typedef Shape<GemmConfig_::kStages,
631 GemmConfig_::OutputTile::kH,
632 GemmConfig_::OutputTile::kD + kSkew>
636 typedef GemmSharedStoreTileAbTraits<
642 typename GlobalTileTraits::Threads,
644 GemmConfig_::kScalarsPerStsB>
645 SharedStoreTileTraits;
648 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
650 typedef WmmaGemmSharedLoadTileBTraits<
658 typename GemmConfig_::Warps,
660 GemmConfig_::InstructionShape::kH * Tile::kW,
662 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
664 Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
666 typename GemmConfig_::InstructionShape>
667 SharedLoadTileTraits;
/// Column-major B specialization for Scalar = Vector<bin1_t, 32>.
/// It mirrors the generic column-major B helper, but divides two sets of
/// quantities by kBitsPerScalar:
///   - every kD extent;
///   - the per-access scalar counts (kScalarsPerLdgB, kScalarsPerStsB).
// NOTE(review): the matching #endif for CUTLASS_USE_SUBBYTE_WMMA is not visible
// in this listing.
672 #ifdef CUTLASS_USE_SUBBYTE_WMMA 673 template <
typename GemmConfig_>
675 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<bin1_t, 32> > {
680 typedef typename GemmConfig_::ScalarB Scalar;
682 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
// Number of bits in one Scalar. Assumes Vector<bin1_t, 32> is bit-packed, so this
// equals its element count (32); TODO confirm.
686 static int const kBitsPerScalar =
sizeof(Scalar) * 8;
692 typename GemmConfig_::InstructionShape>
696 typedef GemmGlobalTileTraits<
704 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
707 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
708 GemmConfig_::OutputTile::kD / kBitsPerScalar>,
710 GemmConfig_::kScalarsPerLdgB / kBitsPerScalar>
// NOTE(review): kSkew is counted in MultiplyAddScalar units but is added to a kD
// extent counted in Scalar units; confirm both are the same packed type.
714 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
716 typedef Shape<GemmConfig_::kStages,
717 GemmConfig_::OutputTile::kH,
718 GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
722 typedef GemmSharedStoreTileAbTraits<
728 typename GlobalTileTraits::Threads,
730 GemmConfig_::kScalarsPerStsB / kBitsPerScalar>
731 SharedStoreTileTraits;
734 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
736 typedef WmmaGemmSharedLoadTileBTraits<
744 typename GemmConfig_::Warps,
746 GemmConfig_::InstructionShape::kH * Tile::kW,
748 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
750 Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
752 typename GemmConfig_::InstructionShape>
753 SharedLoadTileTraits;
/// Column-major B specialization for Scalar = Vector<uint4_t, 8>.
/// Every kD extent and the per-access scalar counts are divided by kInt4PerScalar.
// NOTE(review): the matching #endif for CUTLASS_USE_SUBBYTE_WMMA is not visible
// in this listing.
759 #ifdef CUTLASS_USE_SUBBYTE_WMMA 760 template <
typename GemmConfig_>
762 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<uint4_t, 8> > {
767 typedef typename GemmConfig_::ScalarB Scalar;
769 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
// Two 4-bit values per byte of Scalar. Assumes Vector<uint4_t, 8> is packed into
// 4 bytes, so this equals its element count (8); TODO confirm.
773 static int const kInt4PerScalar =
sizeof(Scalar) * 2;
779 typename GemmConfig_::InstructionShape>
783 typedef GemmGlobalTileTraits<
791 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
794 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
795 GemmConfig_::OutputTile::kD / kInt4PerScalar>,
797 GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
// NOTE(review): kSkew is counted in MultiplyAddScalar units but is added to a kD
// extent counted in Scalar units; confirm both are the same packed type.
801 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
803 typedef Shape<GemmConfig_::kStages,
804 GemmConfig_::OutputTile::kH,
805 GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
809 typedef GemmSharedStoreTileAbTraits<
815 typename GlobalTileTraits::Threads,
817 GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
818 SharedStoreTileTraits;
821 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
823 typedef WmmaGemmSharedLoadTileBTraits<
831 typename GemmConfig_::Warps,
833 GemmConfig_::InstructionShape::kH * Tile::kW,
835 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
837 Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
839 typename GemmConfig_::InstructionShape>
840 SharedLoadTileTraits;
/// Column-major B specialization for Scalar = Vector<int4_t, 8>. The visible body
/// matches the Vector<uint4_t, 8> specialization except for the scalar type.
/// Every kD extent and the per-access scalar counts are divided by kInt4PerScalar.
// NOTE(review): the matching #endif for CUTLASS_USE_SUBBYTE_WMMA is not visible
// in this listing.
846 #ifdef CUTLASS_USE_SUBBYTE_WMMA 847 template <
typename GemmConfig_>
849 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<int4_t, 8> > {
854 typedef typename GemmConfig_::ScalarB Scalar;
856 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
// Two 4-bit values per byte of Scalar. Assumes Vector<int4_t, 8> is packed into
// 4 bytes, so this equals its element count (8); TODO confirm.
860 static int const kInt4PerScalar =
sizeof(Scalar) * 2;
866 typename GemmConfig_::InstructionShape>
870 typedef GemmGlobalTileTraits<
878 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
881 GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
882 GemmConfig_::OutputTile::kD / kInt4PerScalar>,
884 GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
// NOTE(review): kSkew is counted in MultiplyAddScalar units but is added to a kD
// extent counted in Scalar units; confirm both are the same packed type.
888 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
890 typedef Shape<GemmConfig_::kStages,
891 GemmConfig_::OutputTile::kH,
892 GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
896 typedef GemmSharedStoreTileAbTraits<
902 typename GlobalTileTraits::Threads,
904 GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
905 SharedStoreTileTraits;
908 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
910 typedef WmmaGemmSharedLoadTileBTraits<
918 typename GemmConfig_::Warps,
920 GemmConfig_::InstructionShape::kH * Tile::kW,
922 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
924 Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
926 typename GemmConfig_::InstructionShape>
927 SharedLoadTileTraits;
/// Assembles every type WmmaGemmTraits forwards to GemmTraits:
///   - the WMMA GemmConfig;
///   - per-operand tile-traits helpers, selected by (layout, scalar type) from
///     the specializations above;
///   - global load iterators plus Copy transformers;
///   - shared-memory store and load iterators, and the shared load streams;
///   - the accumulator-clearing functor;
///   - the epilogue.
// NOTE(review): several template parameters and typedef-name lines (for example
// GlobalLoadIteratorA, MultiplyAdd, GlobalLoadStreamA/B) are not visible in this
// listing.
939 typename OutputTile_,
947 typename Accumulator_,
949 typename EpilogueFunctor_,
951 typename WarpGemmShape_,
953 typename InstructionShape_,
955 int kScalarsPerLdgA_,
957 int kScalarsPerLdgB_,
960 struct WmmaGemmTraitsHelper {
962 typedef WmmaGemmConfig<kLayoutA_,
// The ScalarA_/ScalarB_ arguments pick the dense or sub-byte specialization.
976 typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig, ScalarA_> GemmTileTraitsHelperA;
978 typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig, ScalarB_> GemmTileTraitsHelperB;
// Operand A: global load iterator -> Copy transform -> shared-memory store iterator.
981 typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
984 typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
986 typedef TileStoreIterator<
typename GemmTileTraitsHelperA::SharedStoreTileTraits,
987 typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
990 SharedStoreIteratorA;
994 SharedStoreIteratorA,
// Operand B: the same pipeline as operand A.
999 typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
1000 GlobalLoadIteratorB;
1002 typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
1004 typedef TileStoreIterator<
typename GemmTileTraitsHelperB::SharedStoreTileTraits,
1005 typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
1008 SharedStoreIteratorB;
1011 GlobalLoadIteratorB,
1012 SharedStoreIteratorB,
// Shared-memory load iterators. Their fragment element type is the helper's
// WmmaMatrix. Each iterator is wrapped in a SharedLoadStream.
1017 typedef TileLoadIterator<
typename GemmTileTraitsHelperA::SharedLoadTileTraits,
1018 typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
1022 typename GemmTileTraitsHelperA::WmmaMatrix,
1024 SharedLoadIteratorA;
1026 typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
1028 typedef TileLoadIterator<
typename GemmTileTraitsHelperB::SharedLoadTileTraits,
1029 typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
1033 typename GemmTileTraitsHelperB::WmmaMatrix,
1035 SharedLoadIteratorB;
1037 typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
// NOTE(review): this member typedef reuses the name of the class template it
// refers to. Some compilers may diagnose that as the declaration "changing the
// meaning" of the name in class scope. Qualifying the template name with its
// namespace would avoid the issue; verify with the supported host compilers.
1042 typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
// Epilogue: WMMA epilogue helper -> simplified epilogue traits -> GemmEpilogue.
1045 typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
1047 typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
1050 typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
/// Chooses the per-warp accumulator shape for a given OutputTile_. WmmaGemmTraits
/// uses WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape as its default
/// WarpGemmShape_. DefaultShape_ defaults to Shape<64, 32, 64>; how it is
/// combined with OutputTile_ is not visible in this listing.
1055 template <
typename OutputTile_,
typename DefaultShape_ = Shape<64, 32, 64> >
1056 struct WmmaGemmAccumulatorsPerWarp {
/// Top-level traits for a WMMA-based GEMM. Visible defaults:
///   - OutputTile Shape<64, 128, 128>;
///   - half A and B, float C, LinearScaling<ScalarC_> epilogue;
///   - Accumulator equal to ScalarC_;
///   - 16x16x16 instruction shape;
///   - 8 scalars per global load for A and B;
///   - int index type.
/// The body is empty: everything comes from the GemmTraits base, instantiated with
/// the types Helper_ (WmmaGemmTraitsHelper) assembles plus IdentityBlockSwizzle.
// NOTE(review): the layout template parameters and the rest of the Helper_
// default argument list are not visible in this listing.
1068 typename OutputTile_ = Shape<64, 128, 128>,
1070 typename ScalarA_ = half,
1072 typename ScalarB_ = half,
1074 typename ScalarC_ = float,
1076 typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
1078 typename Accumulator_ = ScalarC_,
1080 typename WarpGemmShape_ =
typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
1082 typename InstructionShape_ = Shape<16, 16, 16>,
1084 int kScalarsPerLdgA_ = 8,
1086 int kScalarsPerLdgB_ = 8,
1088 typename Index_ = int,
1090 typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
1103 struct WmmaGemmTraits :
public GemmTraits<
1105 typename Helper_::GemmConfig,
1107 typename Helper_::GlobalLoadStreamA,
1109 typename Helper_::GlobalLoadStreamB,
1111 typename Helper_::SharedLoadStreamA,
1113 typename Helper_::SharedLoadStreamB,
1115 typename Helper_::Epilogue,
1117 IdentityBlockSwizzle,
1121 typename Helper_::ClearAccumulators> {};
1128 #endif // defined CUTLASS_USE_WMMA_API Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Defines iterators for efficiently loading and storing to global memory.
Defines structural properties of complete GEMM computation.
Defines structural properties of WMMA GEMM's epilogue phase.
Definition: load_store.h:48
Definition: tile_iterator.h:65
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the computed matrix product.
Defines iterators for efficiently loading and storing tiles to and from shared memory.
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_config.h:90
Definition: matrix_traits.h:159
Definition: matrix_traits.h:357
Defines tile iterator traits for loading thread block-level tile from global memory.
Definition: matrix_traits.h:159
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
Definition: matrix_traits.h:357
Implements a software-pipelined efficient GEMM.
Defines structural properties of the GEMM epilogue.
Shape<(A_::kD < B_::kD ? A_::kD : B_::kD), (A_::kH < B_::kH ? A_::kH : B_::kH), (A_::kW < B_::kW ? A_::kW : B_::kW), (A_::kC < B_::kC ? A_::kC : B_::kC)> Shape
Definition: shape.h:159
Defines conversion operations among Fragments of different base type.