// Everything from here to the matching #endif is compiled only when
// CUTLASS_USE_WMMA_API is defined.
31 #ifdef CUTLASS_USE_WMMA_API 59 typename Accumulator_,
61 typename AccumulatorsPerWarp_,
63 typename InstructionShape_,
// WmmaGemmConfig publicly derives from GemmConfig<...> and forwards the
// arguments listed below to it.
// NOTE(review): this listing omits part of the template parameter list and
// several base-class arguments; only the visible ones are annotated.
68 struct WmmaGemmConfig :
public GemmConfig<
// A WmmaGemmMultiplyAdd instantiation is passed as one argument; its first
// visible parameter is kLayoutA_. A MatrixLayout::kColumnMajor argument
// follows (which template it belongs to is not visible here).
80 WmmaGemmMultiplyAdd<kLayoutA_,
84 MatrixLayout::kColumnMajor,
// Each of the next three arguments is the number of ScalarC_ elements that
// fit in 16 bytes.
// NOTE(review): presumably per-access vector widths for the C/D matrix --
// confirm against GemmConfig's parameter order.
101 16 / sizeof(ScalarC_),
103 16 / sizeof(ScalarC_),
105 16 / sizeof(ScalarC_),
// Primary template of the A-operand tile-traits helper, parameterized on the
// matrix layout (kLayout_) and the GEMM configuration (GemmConfig_).
// The primary template is empty; members are provided by the partial
// specializations for MatrixLayout::kColumnMajor and MatrixLayout::kRowMajor.
111 template <enum MatrixLayout::Kind kLayout_,
typename GemmConfig_>
112 struct WmmaGemmTileTraitsHelperA {};
// Partial specialization of the A-operand helper for column-major A.
// It publicly derives from GemmTileTraitsHelperA<kColumnMajor, GemmConfig_>
// and declares the members shown below.
// NOTE(review): several lines of this definition are missing from the listing
// (including some typedef names and the closing brace); the comments cover
// only what is visible.
116 template <
typename GemmConfig_>
117 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
118 :
public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
// Shorthand for the inherited helper.
120 typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
// Number of Base::MultiplyAddScalar elements that fit in 16 bytes. It is
// added to the W extent of the shape declared next.
// NOTE(review): presumably padding for the shared-memory tile -- confirm.
123 static int const kSkew = 16 /
sizeof(
typename Base::MultiplyAddScalar);
// Shape with extents (kStages, OutputTile::kD, OutputTile::kW + kSkew).
// NOTE(review): the typedef name is not visible; presumably this is the
// `Tile` referenced by kScalarsPerIteration below -- confirm.
125 typedef Shape<GemmConfig_::kStages,
126 GemmConfig_::OutputTile::kD,
127 GemmConfig_::OutputTile::kW + kSkew>
// Trailing arguments of a template whose name is not visible in this listing.
133 typename Base::MultiplyAddScalar,
134 typename GemmConfig_::InstructionShape>
// Traits for the shared-memory store of A, built from the multiply-add scalar
// type, the thread arrangement of the base's global tile traits, and
// GemmConfig_::kScalarsPerStsA (other arguments are not visible here).
138 typedef GemmSharedStoreTileAbTraits<
140 typename Base::MultiplyAddScalar,
144 typename Base::GlobalTileTraits::Threads,
146 GemmConfig_::kScalarsPerStsA>
147 SharedStoreTileTraits;
// Product of the instruction shape's W extent and the warp arrangement's W
// extent.
150 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
// Product of Tile::kW (which includes kSkew if Tile is the shape above) and
// the instruction shape's D extent.
152 static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
// Traits for the shared-memory load of A.
// NOTE(review): the two Shape arguments are presumably the iteration counts
// and the per-iteration strides -- confirm against
// WmmaGemmSharedLoadTileATraits' parameter list.
154 typedef WmmaGemmSharedLoadTileATraits<
158 typename Base::MultiplyAddScalar,
162 typename GemmConfig_::Warps,
164 GemmConfig_::InstructionShape::kW,
166 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
168 Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
170 typename GemmConfig_::InstructionShape>
171 SharedLoadTileTraits;
// Partial specialization of the A-operand helper for row-major A.
// Unlike the column-major specialization, no base class is visible here; the
// scalar types and tile traits are declared directly in this struct.
// NOTE(review): several lines of this definition are missing from the listing
// (including some typedef names and the closing brace); the comments cover
// only what is visible.
176 template <
typename GemmConfig_>
177 struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
// Element type of A as declared by the configuration.
182 typedef typename GemmConfig_::ScalarA Scalar;
// Element type of A as seen by the multiply-add functor.
184 typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
// Trailing argument of a template whose name is not visible in this listing.
190 typename GemmConfig_::InstructionShape>
// Global-memory tile traits for A. Visible arguments: a
// (1, OutputTile::kW, OutputTile::kD) shape, a
// (1, kThreads / OutputTile::kD, OutputTile::kD) shape, and kScalarsPerLdgA.
// NOTE(review): presumably tile extent, thread arrangement and per-access
// vector width respectively -- confirm against GemmGlobalTileTraits.
194 typedef GemmGlobalTileTraits<
202 Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
204 Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
206 GemmConfig_::kScalarsPerLdgA>
// Number of MultiplyAddScalar elements that fit in 16 bytes. It is added to
// the last extent of the shape declared next.
// NOTE(review): presumably padding for the shared-memory tile -- confirm.
210 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
// Shape with extents (kStages, OutputTile::kW, OutputTile::kD + kSkew).
// NOTE(review): the typedef name is not visible; presumably this is the
// `Tile` referenced further down -- confirm.
212 typedef Shape<GemmConfig_::kStages,
213 GemmConfig_::OutputTile::kW,
214 GemmConfig_::OutputTile::kD + kSkew>
// Traits for the shared-memory store of A, using the global tile's thread
// arrangement and GemmConfig_::kScalarsPerStsA (other arguments are not
// visible here).
218 typedef GemmSharedStoreTileAbTraits<
224 typename GlobalTileTraits::Threads,
226 GemmConfig_::kScalarsPerStsA>
227 SharedStoreTileTraits;
// Product of the instruction shape's W extent and the warp arrangement's W
// extent.
230 static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
// Traits for the shared-memory load of A. Compared with the column-major
// specialization, the scalar counts here are multiplied by Tile::kW.
// NOTE(review): the two Shape arguments are presumably the iteration counts
// and the per-iteration strides -- confirm against
// WmmaGemmSharedLoadTileATraits' parameter list.
232 typedef WmmaGemmSharedLoadTileATraits<
240 typename GemmConfig_::Warps,
242 GemmConfig_::InstructionShape::kW * Tile::kW,
244 Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
246 Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
248 typename GemmConfig_::InstructionShape>
249 SharedLoadTileTraits;
// Primary template of the B-operand tile-traits helper, parameterized on the
// matrix layout (kLayout_) and the GEMM configuration (GemmConfig_).
// The primary template is empty; members are provided by the partial
// specializations for MatrixLayout::kRowMajor and MatrixLayout::kColumnMajor.
254 template <enum MatrixLayout::Kind kLayout_,
typename GemmConfig_>
255 struct WmmaGemmTileTraitsHelperB {};
// Partial specialization of the B-operand helper for row-major B.
// It publicly derives from GemmTileTraitsHelperB<kRowMajor, GemmConfig_>
// and declares the members shown below.
// NOTE(review): several lines of this definition are missing from the listing
// (including some typedef names and the closing brace); the comments cover
// only what is visible.
259 template <
typename GemmConfig_>
260 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
261 :
public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
// Shorthand for the inherited helper.
263 typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
// Number of Base::MultiplyAddScalar elements that fit in 16 bytes. It is
// added to the last extent of the shape declared next.
// NOTE(review): presumably padding for the shared-memory tile -- confirm.
266 static int const kSkew = 16 /
sizeof(
typename Base::MultiplyAddScalar);
// Shape with extents (kStages, OutputTile::kD, OutputTile::kH + kSkew).
// NOTE(review): the typedef name is not visible; presumably this is the
// `Tile` referenced by kScalarsPerIteration below -- confirm.
268 typedef Shape<GemmConfig_::kStages,
269 GemmConfig_::OutputTile::kD,
270 GemmConfig_::OutputTile::kH + kSkew>
// Trailing arguments of a template whose name is not visible in this listing.
276 typename Base::MultiplyAddScalar,
277 typename GemmConfig_::InstructionShape>
// Traits for the shared-memory store of B, built from the multiply-add scalar
// type, the thread arrangement of the base's global tile traits, and
// GemmConfig_::kScalarsPerStsB (other arguments are not visible here).
281 typedef GemmSharedStoreTileAbTraits<
283 typename Base::MultiplyAddScalar,
287 typename Base::GlobalTileTraits::Threads,
289 GemmConfig_::kScalarsPerStsB>
290 SharedStoreTileTraits;
// Product of the instruction shape's H extent and the warp arrangement's H
// extent (the B-side counterpart of the A helper's kW-based value).
293 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
// Product of Tile::kW (which includes kSkew if Tile is the shape above) and
// the instruction shape's D extent.
295 static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
// Traits for the shared-memory load of B.
// NOTE(review): the two Shape arguments are presumably the iteration counts
// and the per-iteration strides -- confirm against
// WmmaGemmSharedLoadTileBTraits' parameter list.
297 typedef WmmaGemmSharedLoadTileBTraits<
301 typename Base::MultiplyAddScalar,
305 typename GemmConfig_::Warps,
307 GemmConfig_::InstructionShape::kH,
309 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
311 Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
313 typename GemmConfig_::InstructionShape>
314 SharedLoadTileTraits;
// Partial specialization of the B-operand helper for column-major B.
// Unlike the row-major specialization, no base class is visible here; the
// scalar types and tile traits are declared directly in this struct.
// NOTE(review): several lines of this definition are missing from the listing
// (including some typedef names and the closing brace); the comments cover
// only what is visible.
319 template <
typename GemmConfig_>
320 struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
// Element type of B as declared by the configuration.
325 typedef typename GemmConfig_::ScalarB Scalar;
// Element type of B as seen by the multiply-add functor.
327 typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
// Trailing argument of a template whose name is not visible in this listing.
333 typename GemmConfig_::InstructionShape>
// Global-memory tile traits for B. Visible arguments: a
// (1, OutputTile::kH, OutputTile::kD) shape, a
// (1, kThreads / OutputTile::kD, OutputTile::kD) shape, and kScalarsPerLdgB.
// NOTE(review): presumably tile extent, thread arrangement and per-access
// vector width respectively -- confirm against GemmGlobalTileTraits.
337 typedef GemmGlobalTileTraits<
345 Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
347 Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
349 GemmConfig_::kScalarsPerLdgB>
// Number of MultiplyAddScalar elements that fit in 16 bytes. It is added to
// the last extent of the shape declared next.
// NOTE(review): presumably padding for the shared-memory tile -- confirm.
353 static int const kSkew = 16 /
sizeof(MultiplyAddScalar);
// Shape with extents (kStages, OutputTile::kH, OutputTile::kD + kSkew).
// NOTE(review): the typedef name is not visible; presumably this is the
// `Tile` referenced further down -- confirm.
355 typedef Shape<GemmConfig_::kStages,
356 GemmConfig_::OutputTile::kH,
357 GemmConfig_::OutputTile::kD + kSkew>
// Traits for the shared-memory store of B, using the global tile's thread
// arrangement and GemmConfig_::kScalarsPerStsB (other arguments are not
// visible here).
361 typedef GemmSharedStoreTileAbTraits<
367 typename GlobalTileTraits::Threads,
369 GemmConfig_::kScalarsPerStsB>
370 SharedStoreTileTraits;
// Product of the instruction shape's H extent and the warp arrangement's H
// extent.
373 static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
// Traits for the shared-memory load of B. Compared with the row-major
// specialization, the scalar counts here are multiplied by Tile::kW.
// NOTE(review): the two Shape arguments are presumably the iteration counts
// and the per-iteration strides -- confirm against
// WmmaGemmSharedLoadTileBTraits' parameter list.
375 typedef WmmaGemmSharedLoadTileBTraits<
383 typename GemmConfig_::Warps,
385 GemmConfig_::InstructionShape::kH * Tile::kW,
387 Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
389 Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
391 typename GemmConfig_::InstructionShape>
392 SharedLoadTileTraits;
// Template parameters of WmmaGemmTraitsHelper (partial listing: the opening
// `template <` and several parameters are missing from this view).
403 typename OutputTile_,
407 typename Accumulator_,
409 typename EpilogueFunctor_,
411 typename AccumulatorsPerWarp_,
413 typename InstructionShape_,
415 int kScalarsPerLdgA_,
417 int kScalarsPerLdgB_,
// WmmaGemmTraitsHelper assembles the component types (configuration, global
// and shared iterators/streams, epilogue) from the template parameters.
// WmmaGemmTraits reads GemmConfig, GlobalLoadStreamA/B, SharedLoadStreamA/B,
// Epilogue and ClearAccumulators from its Helper_ parameter, which defaults
// to an instantiation of this struct.
// NOTE(review): many typedef names and the closing brace are missing from the
// listing; names mentioned as "presumably" are inferred from later uses.
420 struct WmmaGemmTraitsHelper {
// The GEMM configuration, an instantiation of WmmaGemmConfig (presumably
// typedef'd as `GemmConfig`, which the lines below refer to).
422 typedef WmmaGemmConfig<kLayoutA_,
427 AccumulatorsPerWarp_,
// Layout-specific tile-traits helpers for the A and B operands.
434 typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
436 typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
// Global-memory iterator for A (presumably named GlobalLoadIteratorA, as used
// on the next line).
439 typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
// Transformer applied to the fragment loaded from global memory for A: a Copy
// over the iterator's Fragment type.
442 typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
// Iterator that stores the A tile using the helper's SharedStoreTileTraits.
444 typedef TileStoreIterator<
typename GemmTileTraitsHelperA::SharedStoreTileTraits,
445 typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
448 SharedStoreIteratorA;
// Stream combining the global load iterator, shared store iterator and
// transformer for A (presumably named GlobalLoadStreamA).
450 typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
// Global-memory iterator for B (presumably named GlobalLoadIteratorB, as used
// on the next line).
454 typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
// Transformer applied to the fragment loaded from global memory for B.
457 typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
// Iterator that stores the B tile using the helper's SharedStoreTileTraits.
459 typedef TileStoreIterator<
typename GemmTileTraitsHelperB::SharedStoreTileTraits,
460 typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
463 SharedStoreIteratorB;
// Stream combining the global load iterator, shared store iterator and
// transformer for B (presumably named GlobalLoadStreamB).
465 typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
// Iterator that loads the A tile using the helper's SharedLoadTileTraits; the
// helper's WmmaMatrix type is one of its arguments (presumably named
// SharedLoadIteratorA, as used by SharedLoadStreamA).
469 typedef TileLoadIterator<
typename GemmTileTraitsHelperA::SharedLoadTileTraits,
470 typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
474 typename GemmTileTraitsHelperA::WmmaMatrix,
// Stream wrapping the shared load iterator for A.
478 typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
// Iterator that loads the B tile using the helper's SharedLoadTileTraits; the
// helper's WmmaMatrix type is one of its arguments (presumably named
// SharedLoadIteratorB, as used by SharedLoadStreamB).
480 typedef TileLoadIterator<
typename GemmTileTraitsHelperB::SharedLoadTileTraits,
481 typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
485 typename GemmTileTraitsHelperB::WmmaMatrix,
// Stream wrapping the shared load iterator for B.
489 typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
// Accumulator-clearing functor over the multiply-add's ScalarC type. Note that
// the member typedef reuses the name of the ClearAccumulators template it
// instantiates. (`MultiplyAdd` is declared on a line not visible here.)
494 typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
// WMMA-specific epilogue traits helper, passed into the simplified epilogue
// traits below.
497 typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
// Epilogue traits (presumably named GemmEpilogueTraits, as used on the next
// line).
499 typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
// The epilogue type built from those traits.
502 typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
// WmmaGemmAccumulatorsPerWarp takes the output tile and a default shape
// (Shape<64, 32, 64> unless overridden). WmmaGemmTraits reads its nested
// `Shape` type as the default for AccumulatorsPerWarp_.
// NOTE(review): the body of this struct is missing from the listing, so how
// `Shape` is derived from OutputTile_ and DefaultShape_ cannot be stated here.
507 template <
typename OutputTile_,
typename DefaultShape_ = Shape<64, 32, 64> >
508 struct WmmaGemmAccumulatorsPerWarp {
// Template parameters of WmmaGemmTraits with their defaults (partial listing:
// the opening `template <` and the leading parameters are missing from this
// view).
// Output tile shape; defaults to Shape<64, 128, 128>.
520 typename OutputTile_ = Shape<64, 128, 128>,
// Element type for C; defaults to float.
522 typename ScalarC_ = float,
// Epilogue functor; defaults to LinearScaling over ScalarC_.
524 typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
// Accumulator type; defaults to ScalarC_.
526 typename Accumulator_ = ScalarC_,
// Accumulators-per-warp shape; defaults to the Shape computed by
// WmmaGemmAccumulatorsPerWarp for the chosen output tile.
528 typename AccumulatorsPerWarp_ =
typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
// Shape of a single instruction; defaults to Shape<16, 16, 16>.
530 typename InstructionShape_ = Shape<16, 16, 16>,
// Scalars per global load for A and B; both default to 8.
532 int kScalarsPerLdgA_ = 8,
534 int kScalarsPerLdgB_ = 8,
// Index type; defaults to int.
536 typename Index_ = int,
// Helper that assembles the component types; defaults to an instantiation of
// WmmaGemmTraitsHelper built from the parameters above.
538 typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
544 AccumulatorsPerWarp_,
// WmmaGemmTraits has an empty body: it only selects the GemmTraits base class
// using the types exposed by Helper_, together with IdentityBlockSwizzle.
// NOTE(review): some base-class arguments are missing from this listing.
549 struct WmmaGemmTraits :
public GemmTraits<
551 typename Helper_::GemmConfig,
553 typename Helper_::GlobalLoadStreamA,
555 typename Helper_::GlobalLoadStreamB,
557 typename Helper_::SharedLoadStreamA,
559 typename Helper_::SharedLoadStreamB,
561 typename Helper_::Epilogue,
563 IdentityBlockSwizzle,
567 typename Helper_::ClearAccumulators> {};
574 #endif // defined CUTLASS_USE_WMMA_API Abstractions for loading and storing matrices using the CUDA WMMA API.
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_traits.h:93
Definition: load_store.h:42
Defines iterators for efficiently loading and storing to global memory.
Defines structural properties of complete GEMM computation.
Defines structural properties of WMMA GEMM's epilogue phase.
Definition: tile_iterator.h:62
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the computed matrix product.
Defines iterators for efficiently loading and storing tiles to and from shared memory.
Definition: matrix_traits.h:36
Definition: tile_iterator.h:67
Definition: matrix_traits.h:43
Defines tile iterator traits for loading thread block-level tile from global memory.
Definition: matrix_traits.h:36
Kind
Definition: matrix_traits.h:36
Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
Definition: matrix_traits.h:43
Implements a software-pipelined efficient GEMM.
Defines structural properties of the GEMM epilogue.
Shape<(A_::kD< B_::kD ? A_::kD :B_::kD),(A_::kH< B_::kH ? A_::kH :B_::kH),(A_::kW< B_::kW ? A_::kW :B_::kW),(A_::kC< B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:148
Defines conversion operations among Fragments of different base type.