// Helper struct that derives the types used by the WMMA GEMM epilogue from a
// GEMM configuration (GemmConfig_), an epilogue functor (EpilogueFunctor_) and
// an index type (Index_, which defaults to int). The whole definition is
// guarded by CUTLASS_USE_WMMA_API.
//
// NOTE(review): this text appears to come from a rendered source listing. Each
// line still carries its original line number, and several declarations below
// are visibly cut short (flagged individually). Restore the missing lines from
// the original header before attempting to compile this.
31 #ifdef CUTLASS_USE_WMMA_API 48 template <
typename GemmConfig_,
typename EpilogueFunctor_,
typename Index_ =
int>
49 struct WmmaGemmEpilogueTraitsHelper {
// Scalar type, taken from the epilogue functor.
51 typedef typename EpilogueFunctor_::Scalar Scalar;
// Output tile shape, taken from the GEMM configuration.
53 typedef typename GemmConfig_::OutputTile OutputTile;
// Quotient (integer division) of the per-warp accumulator extent in H and the
// instruction shape's extent in H.
56 static int const kWmmasPerH =
57 GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
// Iteration shape: kWmmasPerH in the last component, 1 in the other two.
59 typedef Shape<1, 1, kWmmasPerH> Iterations;
// Delta shape: all components are zero.
61 typedef Shape<0, 0, 0> Delta;
// The epilogue functor, re-exported under a fixed name.
63 typedef EpilogueFunctor_ Functor;
// Traits built from the functor's scalar type and the configuration's output
// tile, warp arrangement and instruction shape.
// NOTE(review): the embedded numbering jumps from 66 to 70, so a leading
// template argument may be missing here — confirm against the original header.
66 typedef WmmaGemmSharedStoreTileDTraits<
70 typename Functor::Scalar,
72 typename GemmConfig_::OutputTile,
74 typename GemmConfig_::Warps,
76 typename GemmConfig_::InstructionShape>
77 SharedStoreTileTraits;
// NOTE(review): the next line is the tail of a template argument list whose
// opening typedef is not present in this listing.
82 typename GemmConfig_::InstructionShape>
// NOTE(review): this typedef is truncated after its second template argument
// and has no visible alias name. Presumably it defines SharedStoreIteratorD,
// which the following declarations reference — confirm against the original.
86 typedef TileStoreIterator<SharedStoreTileTraits,
87 typename SharedStoreTileTraits::Scalar,
// Copy transformer applied to the fragment type of SharedStoreIteratorD.
96 typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
// Traits built from the functor's scalar type, the tile type of
// SharedStoreIteratorD, a shape combining the warp count
// (ShapeCount<Warps>::kCount) with the warp size, and kScalarsPerLdsD.
// NOTE(review): the embedded numbering jumps from 99 to 101, so a leading
// template argument may be missing here — confirm against the original header.
99 typedef WmmaGemmSharedLoadTileDTraits<
101 typename Functor::Scalar,
103 typename SharedStoreIteratorD::Tile,
105 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
107 GemmConfig_::kScalarsPerLdsD>
108 SharedLoadTileTraits;
// NOTE(review): this typedef is truncated after its second template argument
// and its alias name is not visible in this listing.
111 typedef TileLoadIterator<SharedLoadTileTraits,
112 typename SharedLoadTileTraits::Scalar,
// Traits for the C-side iterator: const-qualified ScalarC, a shape combining
// the warp count with the warp size, and kScalarsPerLdgC.
// NOTE(review): the line ending in "OutputTile::kW>," closes a template
// argument whose opening is missing from this listing.
118 typedef WmmaGemmGlobalIteratorCdTraits<
120 typename GemmConfig_::ScalarC
const,
124 GemmConfig_::OutputTile::kW>,
126 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
128 GemmConfig_::kScalarsPerLdgC>
129 GlobalLoadTileTraits;
// Iterator for C, instantiated with the traits above and the index type.
132 typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
// Copy transformer applied to the fragment type of GlobalLoadIteratorC.
134 typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
// Traits for the D-side iterator: ScalarD (not const-qualified, unlike the
// C-side traits), the same warp-count/warp-size shape, and kScalarsPerStgD.
// NOTE(review): as above, the argument closed by "OutputTile::kW>," has no
// visible opening in this listing.
137 typedef WmmaGemmGlobalIteratorCdTraits<
139 typename GemmConfig_::ScalarD,
143 GemmConfig_::OutputTile::kW>,
145 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
147 GemmConfig_::kScalarsPerStgD>
148 GlobalStoreTileTraits;
// Iterator for D, instantiated with the traits above and the index type.
151 typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
// Copy transformer applied to the fragment type of GlobalStoreIteratorD.
// NOTE(review): the struct's closing "};" is not present in this listing.
153 typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
161 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:42
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Implements the BLAS linear scaling function alpha*AB + beta*C.
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: tile_iterator.h:62
Definition: matrix_traits.h:43
Defines a type for restructuring a tile.
Definition: tile_iterator.h:67
Defines tile iterator traits for loading thread block-level tile from global memory.
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
Definition: matrix_traits.h:36
Implements efficient loading of the thread block-level tile from global memory and storing to shared memory.
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEMM pipeline.
Defines conversion operations among Fragments of different base type.
Defines iterator traits for efficiently loading and storing fragment to and from shared memory...