51 template <enum MatrixLayout::Kind,
typename GemmConfig_>
56 template <
typename GemmConfig_>
62 typedef typename GemmConfig_::ScalarA
Scalar;
79 GemmConfig_::kScalarsPerLdgA>
87 Shape<GemmConfig_::kStages,
88 GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
89 GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
93 GemmConfig_::kScalarsPerStsA>
101 typename GemmConfig_::OutputTile,
103 typename GemmConfig_::Warps,
105 typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
107 typename GemmConfig_::InstructionShape,
109 GemmConfig_::kStages,
111 GemmConfig_::kScalarsPerLdsA,
119 template <
typename GemmConfig_>
125 typedef typename GemmConfig_::ScalarA
Scalar;
140 Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
142 GemmConfig_::kScalarsPerLdgA>
149 GlobalTileTraits::Threads::kW * kScalarsIn4B;
156 Shape<GemmConfig_::kStages,
157 GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
158 GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
162 GemmConfig_::kScalarsPerStsA,
164 kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
165 SharedStoreTileTraits;
172 typename GemmConfig_::OutputTile,
174 typename GemmConfig_::Warps,
176 typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
178 typename GemmConfig_::InstructionShape,
180 GemmConfig_::kStages,
182 GemmConfig_::kScalarsPerLdsA,
184 SharedStoreTileTraits::kSkew>
185 SharedLoadTileTraits;
190 template <enum MatrixLayout::Kind,
typename GemmConfig_>
195 template <
typename GemmConfig_>
201 typedef typename GemmConfig_::ScalarB
Scalar;
216 Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
218 GemmConfig_::kScalarsPerLdgB>
225 GlobalTileTraits::Threads::kW * kScalarsIn4B;
232 Shape<GemmConfig_::kStages,
233 GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
234 GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
238 GemmConfig_::kScalarsPerStsB,
240 kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
241 SharedStoreTileTraits;
248 typename GemmConfig_::OutputTile,
250 typename GemmConfig_::Warps,
252 typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
254 typename GemmConfig_::InstructionShape,
256 GemmConfig_::kStages,
258 GemmConfig_::kScalarsPerLdsB,
260 SharedStoreTileTraits::kSkew>
261 SharedLoadTileTraits;
266 template <
typename GemmConfig_>
272 typedef typename GemmConfig_::ScalarB
Scalar;
289 GemmConfig_::kScalarsPerLdgB>
297 Shape<GemmConfig_::kStages,
298 GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
299 GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
303 GemmConfig_::kScalarsPerStsB>
311 typename GemmConfig_::OutputTile,
313 typename GemmConfig_::Warps,
315 typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
317 typename GemmConfig_::InstructionShape,
319 GemmConfig_::kStages,
321 GemmConfig_::kScalarsPerLdsB,
331 typename GemmConfig_,
333 typename GlobalLoadStreamA_,
335 typename GlobalLoadStreamB_,
337 typename SharedLoadStreamA_,
339 typename SharedLoadStreamB_,
345 typename Index_ = int,
374 typedef typename GlobalLoadStreamA_::Scalar
ScalarA;
381 typedef typename GlobalLoadStreamB_::Scalar
ScalarB;
431 template <
typename GemmDesc_>
439 this->
grid = block_swizzle.get_grid_layout(
441 make_Coord_from_shape<OutputTile>());
445 Index offset_to_residue = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
451 desc.A.leading_dim(),
461 desc.B.leading_dim(),
477 typename Epilogue::Scalar alpha,
482 typename Epilogue::Scalar beta,
504 typename Epilogue::Scalar alpha,
507 long long int batch_stride_A,
510 long long int batch_stride_B,
511 typename Epilogue::Scalar beta,
514 long long int batch_stride_C,
517 long long int batch_stride_D,
560 if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
561 SharedLoadStreamB::Iterator::kRequiresLoadFence) {
574 template <
typename GemmTileTraitsHelperA_,
typename GemmTileTraitsHelperB_,
typename Index_>
582 typedef TileStoreIterator<
typename GemmTileTraitsHelperA_::SharedStoreTileTraits,
583 typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar,
600 typedef TileStoreIterator<
typename GemmTileTraitsHelperB_::SharedStoreTileTraits,
601 typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar,
613 typedef TileLoadIterator<
typename GemmTileTraitsHelperA_::SharedLoadTileTraits,
614 typename GemmTileTraitsHelperA_::Scalar,
621 typedef TileLoadIterator<
typename GemmTileTraitsHelperB_::SharedLoadTileTraits,
622 typename GemmTileTraitsHelperB_::Scalar,
638 typename GemmConfig_,
642 typename Index_ = int,
654 typename Helper_::GlobalLoadStreamA,
656 typename Helper_::GlobalLoadStreamB,
658 typename Helper_::SharedLoadStreamA,
660 typename Helper_::SharedLoadStreamB,
664 IdentityBlockSwizzle,
668 ClearAccumulators<typename GemmConfig_::Accumulators::Element> > {
Epilogue::SharedStorage epilogue
Definition: gemm_traits.h:555
GEMM problem description.
Definition: gemm_desc.h:50
GlobalLoadStreamA_ GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: gemm_traits.h:370
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue.h:85
GlobalLoadStream< GemmOperand::kA, GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: gemm_traits.h:592
Definition: load_store.h:41
SharedLoadStreamA_ SharedLoadStreamA
The stream for A to load from shared memory.
Definition: gemm_traits.h:384
Definition: gemm_shared_tile.h:128
GlobalLoadStreamB_ GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: gemm_traits.h:377
Definition: gemm_shared_tile.h:80
static int const kThreads
The number of threads.
Definition: gemm_config.h:103
TileStoreIterator< typename GemmTileTraitsHelperA_::SharedStoreTileTraits, typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
The iterator to store A to shared memory.
Definition: gemm_traits.h:586
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Epilogue::ScalarD ScalarD
Definition: gemm_traits.h:394
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, typename Epilogue::Scalar alpha, ScalarA const *d_a, Index lda, ScalarB const *d_b, Index ldb, typename Epilogue::Scalar beta, ScalarC const *d_c, Index ldc, ScalarD *d_d, Index ldd)
Helper to construct a GEMM params using a BLAS-like API.
Definition: gemm_traits.h:474
The storage in shared memory.
Definition: gemm_traits.h:551
SharedLoadStream< SharedLoadIteratorB > SharedLoadStreamB
The stream to load B from shared memory.
Definition: gemm_traits.h:627
Definition: gemm_global_tile.h:70
Defines a structure containing shared storage for each pair.
Definition: gemm_stream_pair.h:91
GlobalLoadStream< GemmOperand::kB, GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: gemm_traits.h:610
GemmConfig_::ScalarB Scalar
The input scalar.
Definition: gemm_traits.h:201
GemmSharedStoreTileAbTraits< MultiplyAddScalar, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/GemmConfig_::InstructionShape::kD, GemmConfig_::OutputTile::kH *GemmConfig_::InstructionShape::kD >, typename GlobalTileTraits::Threads, GemmConfig_::kScalarsPerStsB > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for B^T.
Definition: gemm_traits.h:304
Definition: gemm_coord.h:43
GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ > This_
This traits.
Definition: gemm_traits.h:359
SharedLoadStreamB_ SharedLoadStreamB
The stream for B to load from shared memory.
Definition: gemm_traits.h:386
Defines structures and helpers to launch CUDA kernels within CUTLASS.
GlobalLoadStreamPair< GlobalLoadStreamA, GlobalLoadStreamB, GemmConfig::kResidueInProlog > GlobalLoadStream
Assemble the global load streams for A/B.
Definition: gemm_traits.h:407
GemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^N.
Definition: gemm_traits.h:219
ThreadblockTileStorage threadblock_tile
Stores the threadblock tile.
Definition: gemm_traits.h:541
SharedLoadStream< SharedLoadIteratorA > SharedLoadStreamA
The stream to load A from shared memory.
Definition: gemm_traits.h:619
Definition: gemm_shared_tile.h:38
GemmSharedLoadTileATraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsA, 0 > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for A^N.
Definition: gemm_traits.h:114
Epilogue_ Epilogue
The epilogue.
Definition: gemm_traits.h:391
GlobalLoadStreamA_::Scalar ScalarA
The scalar for A.
Definition: gemm_traits.h:374
Definition: tile_iterator.h:65
GemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kColumnMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^N.
Definition: gemm_traits.h:80
GlobalLoadStream::Params global_to_shared_stream
Parameters object for the global load stream.
Definition: gemm_traits.h:422
GemmConfig_::ScalarA Scalar
The input scalar.
Definition: gemm_traits.h:62
Definition: gemm_shared_tile.h:200
Definition: gemm_global_tile.h:163
Epilogue::ScalarC ScalarC
The scalars in the epilogue.
Definition: gemm_traits.h:393
GemmConfig::MultiplyAdd MultiplyAdd
The multiply-add functor.
Definition: gemm_traits.h:389
static CUTLASS_DEVICE void shared_load_fence(bool in_loop)
The memory fence for shared loads.
Definition: gemm_traits.h:559
GemmConfig_ GemmConfig
The configuration.
Definition: gemm_traits.h:365
Definition: gemm_global_stream.h:52
Definition: gemm_traits.h:191
Definition: clear_accumulators.h:38
Parameters object constructable on the host.
Definition: gemm_traits.h:416
Collect the global load streams for multiplicands.
Definition: gemm_stream_pair.h:173
Copy< typename GlobalLoadIteratorB::Fragment > GlobalTransformerB
The data converter for B before storing to shared memory.
Definition: gemm_traits.h:598
GemmConfig_::ScalarB Scalar
The input scalar.
Definition: gemm_traits.h:272
StreamB::Params stream_b
Parameters object for StreamB.
Definition: gemm_stream_pair.h:67
Defines data layouts of various matrix formats usable by TensorRef and other classes.
Definition: matrix_traits.h:156
GemmGlobalIteratorAb< typename GemmTileTraitsHelperB_::GlobalTileTraits, Index_ > GlobalLoadIteratorB
The global iterator to load B from global memory.
Definition: gemm_traits.h:596
static bool const kResidueInProlog
If true, residue is computed in the prologue.
Definition: gemm_config.h:136
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:399
Definition: gemm_traits.h:539
Collect the global load streams for multiplicands.
Definition: gemm_stream_pair.h:50
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_config.h:90
Definition: matrix_traits.h:159
Defines a fragment based on a Shape<> template.
Structure containing the basic launch configuration of a CUDA kernel.
Definition: kernel_launch.h:38
ClearAccumulators_ ClearAccumulators
Clear the accumulators.
Definition: gemm_traits.h:401
Definition: gemm_shared_stream.h:45
GemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^T.
Definition: gemm_traits.h:143
Parameters object.
Definition: gemm_stream_pair.h:62
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
GemmCoord problem_size
GEMM problem size.
Definition: gemm_traits.h:419
Implements a software-pipelined efficient GEMM.
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, typename Epilogue::Scalar alpha, ScalarA const *d_a, Index lda, long long int batch_stride_A, ScalarB const *d_b, Index ldb, long long int batch_stride_B, typename Epilogue::Scalar beta, ScalarC const *d_c, Index ldc, long long int batch_stride_C, ScalarD *d_d, Index ldd, long long int batch_stride_D, Index batch_count)
Helper to construct a batched GEMM params.
Definition: gemm_traits.h:501
Defines abstractions for efficiently clearing accumulator tiles.
Definition: tensor_ref.h:131
SharedStreamPair< SharedLoadStreamA, SharedLoadStreamB > SharedStream
Assemble the shared load streams for A/B.
Definition: gemm_traits.h:413
static CUTLASS_DEVICE void shared_store_fence(bool in_loop)
The memory fence for shared stores.
Definition: gemm_traits.h:567
GemmConfig_::ScalarA Scalar
The input scalar.
Definition: gemm_traits.h:125
Manages a pair of tile allocations as if they are one allocation.
Definition: tile_allocation.h:100
Definition: gemm_traits.h:52
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: gemm_traits.h:432
Definition: matrix_traits.h:357
Definition: threadblock_swizzle.h:65
GemmSharedStoreTileAbTraits< MultiplyAddScalar, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/GemmConfig_::InstructionShape::kD, GemmConfig_::OutputTile::kW *GemmConfig_::InstructionShape::kD >, typename GlobalTileTraits::Threads, GemmConfig_::kScalarsPerStsA > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for A^N.
Definition: gemm_traits.h:94
GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:274
GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:203
GlobalLoadStreamB_::Scalar ScalarB
The scalar for B.
Definition: gemm_traits.h:381
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Defines properties of GEMM computation that impose some constraints on caller.
Definition: gemm_traits.h:349
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
cutlass::gemm::Gemm< This_ > KernelClass
The struct that consumes this Traits.
Definition: gemm_traits.h:362
SharedStream::Params shared_stream
Parameters object for the shared load stream.
Definition: gemm_traits.h:425
ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:88
BlockSwizzle_ BlockSwizzle
The block swizzle to reorganize the grid.
Definition: gemm_traits.h:397
TileLoadIterator< typename GemmTileTraitsHelperA_::SharedLoadTileTraits, typename GemmTileTraitsHelperA_::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
The iterator to load A from shared memory.
Definition: gemm_traits.h:617
Definition: matrix_traits.h:159
TileLoadIterator< typename GemmTileTraitsHelperB_::SharedLoadTileTraits, typename GemmTileTraitsHelperB_::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
The iterator to load B from shared memory.
Definition: gemm_traits.h:625
GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage
Memory needed to store the threadblock-scoped GEMM tile.
Definition: gemm_traits.h:410
dim3 block
CUDA threadblock dimensions.
Definition: kernel_launch.h:44
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue.h:83
Index_ Index
The index.
Definition: gemm_traits.h:399
GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:64
TileStoreIterator< typename GemmTileTraitsHelperB_::SharedStoreTileTraits, typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
The iterator to store B to shared memory.
Definition: gemm_traits.h:604
Epilogue::Params epilogue
The params for the epilogue.
Definition: gemm_traits.h:428
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Defines a pair of GEMM tile streams.
The shared storage.
Definition: clear_accumulators.h:40
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
MainLoopSharedStorage main_loop
Definition: gemm_traits.h:553
static MatrixLayout::Kind const kLayoutA
The layout of A.
Definition: gemm_traits.h:372
dim3 grid
CUDA grid dimensions.
Definition: kernel_launch.h:41
Definition: matrix_traits.h:357
GlobalLoadStream::SharedStorage global_to_shared_stream
Storage for GEMM global stream.
Definition: gemm_traits.h:544
Parameters object passed to load iterators.
Definition: gemm_stream_pair.h:185
Defines functors for mapping blockIdx to partitions of the GEMM computation.
Definition: gemm_traits.h:575
Implements a software-pipelined efficient GEMM.
GemmGlobalIteratorAb< typename GemmTileTraitsHelperA_::GlobalTileTraits, Index_ > GlobalLoadIteratorA
The global iterator to load A from global memory.
Definition: gemm_traits.h:578
GemmConfig::OutputTile OutputTile
The output tile.
Definition: gemm_traits.h:367
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Copy< typename GlobalLoadIteratorA::Fragment > GlobalTransformerA
The data converter for A before storing to shared memory.
Definition: gemm_traits.h:580
GemmSharedLoadTileBTraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsB, 0 > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for B^T.
Definition: gemm_traits.h:324
ClearAccumulators::SharedStorage clear
Storage for clearing accumulators.
Definition: gemm_traits.h:547
StreamA::Params stream_a
Parameters object for StreamA.
Definition: gemm_stream_pair.h:64
GemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kRowMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^T.
Definition: gemm_traits.h:290
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Defines conversion operations among Fragments of different base type.
Definition: gemm_traits.h:650
OutputTile_ OutputTile
The tile.
Definition: gemm_config.h:88
static MatrixLayout::Kind const kLayoutB
The layout of B.
Definition: gemm_traits.h:379
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:836
GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:127