// WmmaGemmEpilogueTraitsHelper: a compile-time helper struct that derives the typedefs
// (tile traits, iterators, streams and transformers) for a WMMA GEMM epilogue from a GEMM
// configuration and an epilogue functor. It contains only typedefs and one static constant.
// The definition sits behind the CUTLASS_USE_WMMA_API preprocessor guard.
//
// Template parameters:
//   GemmConfig_      - provides OutputTile, Warps, InstructionShape, AccumulatorsPerWarp,
//                      ScalarC, ScalarD, kWarpSize, kScalarsPerLdsD, kScalarsPerLdgC and
//                      kScalarsPerStgD (all of which are read below).
//   EpilogueFunctor_ - provides Scalar; re-exported below as Functor.
//   Index_           - index type forwarded to the global iterators; defaults to int.
//
// NOTE(review): this text looks like a rendered source listing: the listing's own line numbers
// are fused into the code and several lines appear to be missing (flagged individually below).
// Confirm every flagged spot against the original header before relying on this copy.
31 #ifdef CUTLASS_USE_WMMA_API 48 template <
typename GemmConfig_,
typename EpilogueFunctor_,
typename Index_ =
int>
49 struct WmmaGemmEpilogueTraitsHelper {
/// Scalar type, taken from the epilogue functor.
51 typedef typename EpilogueFunctor_::Scalar Scalar;
/// Output tile type, taken from the GEMM configuration.
53 typedef typename GemmConfig_::OutputTile OutputTile;
/// Compile-time integer quotient of the per-warp accumulator extent along H and the
/// instruction-shape extent along H.
56 static int const kWmmasPerH =
57 GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
/// Iteration shape: kWmmasPerH in the last position, 1 in the other two.
59 typedef Shape<1, 1, kWmmasPerH> Iterations;
/// All-zero shape. Presumably the per-iteration strides -- TODO confirm with the epilogue code.
61 typedef Shape<0, 0, 0> Delta;
/// The epilogue functor type, re-exported under a fixed name.
63 typedef EpilogueFunctor_ Functor;
/// Traits used to build the shared-store tile for D, parameterized on the functor scalar,
/// the output tile, the warp arrangement and the instruction shape.
// NOTE(review): the embedded numbering jumps from 66 to 70, so a leading template argument
// may have been lost here -- confirm against the original header.
66 typedef WmmaGemmSharedStoreTileDTraits<
70 typename Functor::Scalar,
72 typename GemmConfig_::OutputTile,
74 typename GemmConfig_::Warps,
76 typename GemmConfig_::InstructionShape>
77 SharedStoreTileTraits;
// NOTE(review): the next line is the tail of a template argument list whose opening typedef
// is not visible in this listing.
82 typename GemmConfig_::InstructionShape>
// NOTE(review): this typedef is truncated; its remaining arguments and its name are missing.
// It is presumably the SharedStoreIteratorD referenced just below -- confirm.
86 typedef TileStoreIterator<SharedStoreTileTraits,
87 typename SharedStoreTileTraits::Scalar,
/// Transformer type for the shared store of D: Copy instantiated on the store iterator's Fragment.
96 typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
/// Traits used to build the shared-load tile for D: functor scalar, the Tile type of the
/// shared-store iterator, a Shape<1, warp count, warp size>, and kScalarsPerLdsD.
99 typedef WmmaGemmSharedLoadTileDTraits<
101 typename Functor::Scalar,
103 typename SharedStoreIteratorD::Tile,
105 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
107 GemmConfig_::kScalarsPerLdsD>
108 SharedLoadTileTraits;
// NOTE(review): this typedef is truncated as well; its remaining arguments and its name are
// missing. It is presumably the SharedLoadIteratorD used on the next line -- confirm.
111 typedef TileLoadIterator<SharedLoadTileTraits,
112 typename SharedLoadTileTraits::Scalar,
/// Stream type wrapping the shared-load iterator for D.
118 typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
/// Traits for the global iterator that reads C: const-qualified ScalarC, a
/// Shape<1, warp count, warp size>, and kScalarsPerLdgC.
// NOTE(review): the line ending in "OutputTile::kW>," closes a template argument whose opening
// is not visible (numbering jumps 123 -> 127) -- confirm against the original header.
121 typedef WmmaGemmGlobalIteratorCdTraits<
123 typename GemmConfig_::ScalarC
const,
127 GemmConfig_::OutputTile::kW>,
129 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
131 GemmConfig_::kScalarsPerLdgC>
132 GlobalLoadTileTraits;
/// Global iterator type for C, built from GlobalLoadTileTraits and indexed with Index_.
135 typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
/// Transformer type for C: Copy instantiated on the global load iterator's Fragment.
137 typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
/// Traits for the global iterator that writes D: ScalarD (not const-qualified), a
/// Shape<1, warp count, warp size>, and kScalarsPerStgD.
// NOTE(review): as with the C traits above, the line ending in "OutputTile::kW>," closes an
// argument whose opening is not visible (numbering jumps 142 -> 146) -- confirm.
140 typedef WmmaGemmGlobalIteratorCdTraits<
142 typename GemmConfig_::ScalarD,
146 GemmConfig_::OutputTile::kW>,
148 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
150 GemmConfig_::kScalarsPerStgD>
151 GlobalStoreTileTraits;
/// Global iterator type for D, built from GlobalStoreTileTraits and indexed with Index_.
154 typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
/// Transformer type for D: Copy instantiated on the global store iterator's Fragment.
156 typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
// NOTE(review): the closing "};" of this struct is not visible in the listing before the
// "#endif" that follows.
164 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Implements the BLAS linear scaling function alpha*AB + beta*C.
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: load_store.h:48
Definition: tile_iterator.h:65
Definition: matrix_traits.h:357
Defines a type for restructuring a tile.
Defines tile iterator traits for loading thread block-level tile from global memory.
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
Definition: matrix_traits.h:159
Implements efficient loading of the thread block-level tile from global memory and storing to shared memory.
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEMM pipeline.
Defines conversion operations among Fragments of different base type.
Defines iterator traits for efficiently loading and storing fragments to and from shared memory, specialized for WMMA-based GEMM.