Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/convert.h"
31 #include "cutlass/matrix_traits.h"
32 #include "cutlass/reshape_tile.h"
34 #include "cutlass/tile_iterator.h"
35 #include "cutlass/kernel_launch.h"
36 
39 #include "cutlass/gemm/gemm_desc.h"
45 #include "cutlass/gemm/gemm.h"
46 namespace cutlass {
47 namespace gemm {
48 
50 
51 template <enum MatrixLayout::Kind, typename GemmConfig_>
53 
55 
// Specialization of the A-operand tile-traits helper for a column-major A (A^N).
// It bundles the scalar types with the traits classes used to load A from global
// memory, store it to shared memory and load it back from shared memory.
// NOTE(review): this is a rendered listing. Lines that were hyperlinks on the
// original page (typedef heads, some template arguments, the typedef names) are
// missing below -- consult the real header before editing the code itself.
56 template <typename GemmConfig_>
57 struct GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
60 
/// The input scalar.
62  typedef typename GemmConfig_::ScalarA Scalar;
/// The scalar stored in shared memory.
64  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
65 
/// The traits class to build the iterator to load data from global memory for A^N.
67  typedef GemmGlobalTileTraits<
68  // That's A.
70  // A is column-major.
72  // The pointer is Scalar const (e.g. float const).
73  Scalar const,
74  // The tile has size KxM in GEMM's terminology.
76  // The threads are distributed as warps x 32 (the traits may reorganize).
78  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
79  GemmConfig_::kScalarsPerLdgA>
81 
/// The traits class to build the iterator to store data to shared memory for A^N.
84  // The pointer is MultiplyAddScalar (e.g. float).
86  // The tile has size KxM in GEMM's terminology.
87  Shape<GemmConfig_::kStages,
88  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
89  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
90  // The threads are distributed as warps x 32 (the traits may reorganize).
92  // The number of scalars per STS (STS.32 or STS.128, etc).
93  GemmConfig_::kScalarsPerStsA>
95 
/// The traits class to build the iterator to load from shared memory for A^N.
98  // The pointer is MultiplyAddScalar const (e.g. float const).
99  MultiplyAddScalar const,
100  // The output tile size.
101  typename GemmConfig_::OutputTile,
102  // The number of warps.
103  typename GemmConfig_::Warps,
104  // The number of threads per warp.
105  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
106  // The shape of the FMA instruction.
107  typename GemmConfig_::InstructionShape,
108  // The number of stages.
109  GemmConfig_::kStages,
110  // The number of scalars per LDS.
111  GemmConfig_::kScalarsPerLdsA,
112  // The skew (none is applied for this layout).
113  0>
115 };
116 
118 
// Specialization of the A-operand tile-traits helper for a row-major A (A^T).
// Besides the scalar types and the global-load / shared-store / shared-load tile
// traits, it derives the skew that is added to the shared-memory tile.
// NOTE(review): this is a rendered listing. Lines that were hyperlinks on the
// original page (typedef heads, some template arguments, some typedef names) are
// missing below -- consult the real header before editing the code itself.
119 template <typename GemmConfig_>
120 struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
123 
/// The input scalar.
125  typedef typename GemmConfig_::ScalarA Scalar;
/// The scalar stored in shared memory.
127  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
128 
/// The traits class to build the iterator to load data from global memory for A^T.
130  typedef GemmGlobalTileTraits<
131  // That's A.
133  // A is row-major.
135  // The pointer is Scalar const (e.g. float const).
136  Scalar const,
137  // The tile has size MxK in GEMM's terminology.
139  // The threads are distributed as (threads / K) x K (the traits may reorganize).
140  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
141  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
142  GemmConfig_::kScalarsPerLdgA>
144 
/// The number of MultiplyAddScalar elements that fit in 4 bytes (1 if the scalar is wider than 4 bytes).
146  static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
/// The skew candidate, derived from the scalar size, the STS width and the thread arrangement (Threads::kW).
148  static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
149  GlobalTileTraits::Threads::kW * kScalarsIn4B;
150 
/// The traits class to build the iterator to store data to shared memory for A^T.
153  // The pointer is MultiplyAddScalar (e.g. float).
155  // The tile has size KxM in GEMM's terminology.
156  Shape<GemmConfig_::kStages,
157  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
158  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
159  // The threads are distributed as (threads / K) x K (the traits may reorganize).
160  typename GlobalTileTraits::Threads,
161  // The number of scalars per STS.
162  GemmConfig_::kScalarsPerStsA,
163  // The skew to avoid bank conflicts added in the tile W dimension.
// The expression below evaluates to max(kSkewA, kScalarsPerLdsA).
164  kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
165  SharedStoreTileTraits;
166 
/// The traits class to build the iterator to load from shared memory for A^T.
169  // The pointer is MultiplyAddScalar const (e.g. float const).
170  MultiplyAddScalar const,
171  // The output tile size.
172  typename GemmConfig_::OutputTile,
173  // The number of warps.
174  typename GemmConfig_::Warps,
175  // The number of threads per warp.
176  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
177  // The shape of the FMA instruction.
178  typename GemmConfig_::InstructionShape,
179  // The number of stages.
180  GemmConfig_::kStages,
181  // The number of scalars per LDS.
182  GemmConfig_::kScalarsPerLdsA,
183  // The skew; reuses the value chosen for the shared-memory store so both sides agree.
184  SharedStoreTileTraits::kSkew>
185  SharedLoadTileTraits;
186 };
187 
189 
190 template <enum MatrixLayout::Kind, typename GemmConfig_>
192 
194 
// Specialization of the B-operand tile-traits helper for a column-major B (B^N).
// Besides the scalar types and the global-load / shared-store / shared-load tile
// traits, it derives the skew that is added to the shared-memory tile.
// NOTE(review): this is a rendered listing. Lines that were hyperlinks on the
// original page (typedef heads, some template arguments, some typedef names) are
// missing below -- consult the real header before editing the code itself.
195 template <typename GemmConfig_>
196 struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
199 
/// The input scalar.
201  typedef typename GemmConfig_::ScalarB Scalar;
/// The scalar stored in shared memory.
203  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
204 
/// The traits class to build the iterator to load data from global memory for B^N.
206  typedef GemmGlobalTileTraits<
207  // That's B.
209  // B is column-major.
211  // The pointer is Scalar const (e.g. float const).
212  Scalar const,
213  // The tile has size NxK in GEMM's terminology.
215  // The threads are distributed as (threads / K) x K (the traits may reorganize).
216  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
217  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
218  GemmConfig_::kScalarsPerLdgB>
220 
/// The number of MultiplyAddScalar elements that fit in 4 bytes (1 if the scalar is wider than 4 bytes).
222  static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
/// The skew candidate, derived from the scalar size, the STS width and the thread arrangement (Threads::kW).
224  static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
225  GlobalTileTraits::Threads::kW * kScalarsIn4B;
226 
/// The traits class to build the iterator to store data to shared memory for B^N.
229  // The pointer is MultiplyAddScalar (e.g. float).
231  // The tile has size KxN in GEMM's terminology.
232  Shape<GemmConfig_::kStages,
233  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
234  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
235  // The threads are distributed as (threads / K) x K (the traits may reorganize).
236  typename GlobalTileTraits::Threads,
237  // The number of scalars per STS.
238  GemmConfig_::kScalarsPerStsB,
239  // The skew to avoid bank conflicts added in the tile W dimension.
// The expression below evaluates to max(kSkewB, kScalarsPerLdsB).
240  kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
241  SharedStoreTileTraits;
242 
/// The traits class to build the iterator to load from shared memory for B^N.
245  // The pointer is MultiplyAddScalar const (e.g. float const).
246  MultiplyAddScalar const,
247  // The output tile size.
248  typename GemmConfig_::OutputTile,
249  // The number of warps.
250  typename GemmConfig_::Warps,
251  // The number of threads per warp.
252  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
253  // The shape of the FMA instruction.
254  typename GemmConfig_::InstructionShape,
255  // The number of stages.
256  GemmConfig_::kStages,
257  // The number of scalars per LDS.
258  GemmConfig_::kScalarsPerLdsB,
259  // The skew; reuses the value chosen for the shared-memory store so both sides agree.
260  SharedStoreTileTraits::kSkew>
261  SharedLoadTileTraits;
262 };
263 
265 
// Specialization of the B-operand tile-traits helper for a row-major B (B^T).
// It bundles the scalar types with the traits classes used to load B from global
// memory, store it to shared memory and load it back from shared memory.
// NOTE(review): this is a rendered listing. Lines that were hyperlinks on the
// original page (typedef heads, some template arguments, the typedef names) are
// missing below -- consult the real header before editing the code itself.
266 template <typename GemmConfig_>
267 struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
270 
/// The input scalar.
272  typedef typename GemmConfig_::ScalarB Scalar;
/// The scalar stored in shared memory.
274  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
275 
/// The traits class to build the iterator to load data from global memory for B^T.
277  typedef GemmGlobalTileTraits<
278  // That's B.
280  // B is row-major.
282  // The pointer is Scalar const (e.g. float const).
283  Scalar const,
284  // The tile has size KxN in GEMM's terminology.
286  // The threads are distributed as warps x 32 (the traits may reorganize).
288  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
289  GemmConfig_::kScalarsPerLdgB>
291 
/// The traits class to build the iterator to store data to shared memory for B^T.
294  // The pointer is MultiplyAddScalar (e.g. float).
296  // The tile has size KxN in GEMM's terminology.
297  Shape<GemmConfig_::kStages,
298  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
299  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
300  // The threads are distributed as warps x 32 (the traits may reorganize).
301  typename GlobalTileTraits::Threads,
302  // The number of scalars per STS (STS.32 or STS.128, etc).
303  GemmConfig_::kScalarsPerStsB>
305 
/// The traits class to build the iterator to load from shared memory for B^T.
308  // The pointer is MultiplyAddScalar const (e.g. float const).
309  MultiplyAddScalar const,
310  // The output tile size.
311  typename GemmConfig_::OutputTile,
312  // The number of warps.
313  typename GemmConfig_::Warps,
314  // The number of threads per warp.
315  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
316  // The shape of the FMA instruction.
317  typename GemmConfig_::InstructionShape,
318  // The number of stages.
319  GemmConfig_::kStages,
320  // The number of scalars per LDS.
321  GemmConfig_::kScalarsPerLdsB,
322  // The skew (none is applied for this layout).
323  0>
325 };
326 
328 
329 template <
331  typename GemmConfig_,
333  typename GlobalLoadStreamA_,
335  typename GlobalLoadStreamB_,
337  typename SharedLoadStreamA_,
339  typename SharedLoadStreamB_,
341  typename Epilogue_,
343  typename BlockSwizzle_ = IdentityBlockSwizzle,
345  typename Index_ = int,
348 
349 struct GemmTraits {
351  typedef GemmTraits<GemmConfig_,
352  GlobalLoadStreamA_,
353  GlobalLoadStreamB_,
354  SharedLoadStreamA_,
355  SharedLoadStreamB_,
356  Epilogue_,
357  BlockSwizzle_,
358  Index_,
359  ClearAccumulators_> This_;
360 
363 
365  typedef GemmConfig_ GemmConfig;
368 
370  typedef GlobalLoadStreamA_ GlobalLoadStreamA;
372  static MatrixLayout::Kind const kLayoutA = GlobalLoadStreamA::kLayout;
374  typedef typename GlobalLoadStreamA_::Scalar ScalarA;
375 
377  typedef GlobalLoadStreamB_ GlobalLoadStreamB;
379  static MatrixLayout::Kind const kLayoutB = GlobalLoadStreamB::kLayout;
381  typedef typename GlobalLoadStreamB_::Scalar ScalarB;
382 
384  typedef SharedLoadStreamA_ SharedLoadStreamA;
386  typedef SharedLoadStreamB_ SharedLoadStreamB;
387 
391  typedef Epilogue_ Epilogue;
393  typedef typename Epilogue::ScalarC ScalarC;
394  typedef typename Epilogue::ScalarD ScalarD;
395 
397  typedef BlockSwizzle_ BlockSwizzle;
399  typedef Index_ Index;
401  typedef ClearAccumulators_ ClearAccumulators;
402 
408 
411 
414 
417 
420 
423 
426 
428  typename Epilogue::Params epilogue;
429 
/// Initialize the parameters from a GEMM problem description.
///
/// Copies the problem size, derives the launch configuration (block and grid),
/// computes the offset to the residue along K, then initializes the A stream,
/// the B stream and the epilogue in that order.
///
/// \param desc problem description; this function reads problem_size, A, B,
///             batch_stride_A and batch_stride_B from it and forwards the whole
///             object to the epilogue.
/// \return the first non-zero error code reported by the A or B stream
///         initialization, otherwise whatever epilogue.initialize(desc) returns.
431  template <typename GemmDesc_>
432  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
433  // Set the problem size.
434  problem_size = desc.problem_size;
435 
436  // Compute the launch configuration: one block of kThreads threads, and a grid
// laid out by the block swizzle from the problem size and the output tile shape.
437  BlockSwizzle block_swizzle;
438  this->block = dim3(GemmConfig::kThreads);
439  this->grid = block_swizzle.get_grid_layout(
440  problem_size,
441  make_Coord_from_shape<OutputTile>());
442 
443  // Compute offset to residue.
// NOTE(review): index 0 of problem_size is taken to be K, matching the
// GemmCoord(k, n, m, ...) construction used by the BLAS-like overloads -- confirm.
444  Index gemm_k = problem_size[0];
// K rounded down to a multiple of OutputTile::kD when K is not an exact multiple;
// 0 when K divides evenly.
445  Index offset_to_residue = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
446 
447  // Initialize the parameters objects for the A and B global load streams.
448  int error_code = global_to_shared_stream.stream_a.initialize(
449  desc.A.data(),
450  desc.batch_stride_A,
451  desc.A.leading_dim(),
452  offset_to_residue
453  );
454  if (error_code) {
455  return error_code;
456  }
457 
458  error_code = global_to_shared_stream.stream_b.initialize(
459  desc.B.data(),
460  desc.batch_stride_B,
461  desc.B.leading_dim(),
462  offset_to_residue
463  );
464 
465  if (error_code) {
466  return error_code;
467  }
468 
469  // The epilogue. Its status is the final result of this function.
470  return epilogue.initialize(desc);
471  }
472 
475  Index n,
476  Index k,
477  typename Epilogue::Scalar alpha,
478  ScalarA const* d_a,
479  Index lda,
480  ScalarB const* d_b,
481  Index ldb,
482  typename Epilogue::Scalar beta,
483  ScalarC const* d_c,
484  Index ldc,
485  ScalarD* d_d,
486  Index ldd) {
488  GemmCoord(k, n, m, 1),
489  alpha,
490  TensorRef<ScalarA const, 2>(d_a, lda),
491  TensorRef<ScalarB const, 2>(d_b, ldb),
492  beta,
493  TensorRef<ScalarC const, 2>(d_c, ldc),
494  TensorRef<ScalarD, 2>(d_d, ldd)
495  );
496 
497  return this->initialize(desc);
498  }
499 
502  Index n,
503  Index k,
504  typename Epilogue::Scalar alpha,
505  ScalarA const* d_a,
506  Index lda,
507  long long int batch_stride_A,
508  ScalarB const* d_b,
509  Index ldb,
510  long long int batch_stride_B,
511  typename Epilogue::Scalar beta,
512  ScalarC const* d_c,
513  Index ldc,
514  long long int batch_stride_C,
515  ScalarD* d_d,
516  Index ldd,
517  long long int batch_stride_D,
518  Index batch_count) {
519 
521  GemmCoord(k, n, m, batch_count),
522  alpha,
523  TensorRef<ScalarA const, 2>(d_a, lda),
524  batch_stride_A,
525  TensorRef<ScalarB const, 2>(d_b, ldb),
526  batch_stride_B,
527  beta,
528  TensorRef<ScalarC const, 2>(d_c, ldc),
529  batch_stride_C,
530  TensorRef<ScalarD, 2>(d_d, ldd),
531  batch_stride_D
532  );
533 
534  return this->initialize(desc);
535  }
536  };
537 
538  // The storage for the main loop + prologue.
542 
545 
548  };
549 
552  // The storage for the main loop.
554  // The storage for the epilogue.
555  typename Epilogue::SharedStorage epilogue;
556  };
557 
/// The memory fence for shared loads.
///
/// Calls __syncthreads() only when the shared-load iterator for A or for B
/// declares kRequiresLoadFence; otherwise this is a no-op.
/// \param in_loop not used by this implementation.
559  static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
560  if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
561  SharedLoadStreamB::Iterator::kRequiresLoadFence) {
562  __syncthreads();
563  }
564  }
565 
/// The memory fence for shared stores.
///
/// Always calls __syncthreads().
/// \param in_loop not used by this implementation.
567  static CUTLASS_DEVICE void shared_store_fence(bool in_loop) {
568  __syncthreads();
569  }
570 };
571 
573 
574 template <typename GemmTileTraitsHelperA_, typename GemmTileTraitsHelperB_, typename Index_>
582  typedef TileStoreIterator<typename GemmTileTraitsHelperA_::SharedStoreTileTraits,
583  typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar,
593 
600  typedef TileStoreIterator<typename GemmTileTraitsHelperB_::SharedStoreTileTraits,
601  typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar,
611 
613  typedef TileLoadIterator<typename GemmTileTraitsHelperA_::SharedLoadTileTraits,
614  typename GemmTileTraitsHelperA_::Scalar,
621  typedef TileLoadIterator<typename GemmTileTraitsHelperB_::SharedLoadTileTraits,
622  typename GemmTileTraitsHelperB_::Scalar,
628 };
629 
631 
632 template <
634  MatrixLayout::Kind kLayoutA_,
636  MatrixLayout::Kind kLayoutB_,
638  typename GemmConfig_,
640  typename Epilogue_,
642  typename Index_ = int,
643  // The configuration for the A matrix.
644  typename GemmTileTraitsHelperA_ = GemmTileTraitsHelperA<kLayoutA_, GemmConfig_>,
645  // The configuration for the B matrix.
646  typename GemmTileTraitsHelperB_ = GemmTileTraitsHelperB<kLayoutB_, GemmConfig_>,
647  // The helper class to create the streams and iterators.
648  typename Helper_ =
651  // The config.
652  GemmConfig_,
653  // The stream to load A from global memory to shared memory.
654  typename Helper_::GlobalLoadStreamA,
655  // The stream to load B from global memory to shared memory.
656  typename Helper_::GlobalLoadStreamB,
657  // The stream to load A from shared memory.
658  typename Helper_::SharedLoadStreamA,
659  // The stream to load B from shared memory.
660  typename Helper_::SharedLoadStreamB,
661  // The epilogue.
662  Epilogue_,
663  // The block swizzle to reorganize the grid.
664  IdentityBlockSwizzle,
665  // The index.
666  Index_,
667  // The tool used to clear accumulators.
668  ClearAccumulators<typename GemmConfig_::Accumulators::Element> > {
669 };
670 
672 
673 } // namespace gemm
674 } // namespace cutlass
Epilogue::SharedStorage epilogue
Definition: gemm_traits.h:555
GEMM problem description.
Definition: gemm_desc.h:50
GlobalLoadStreamA_ GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: gemm_traits.h:370
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue.h:85
GlobalLoadStream< GemmOperand::kA, GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: gemm_traits.h:592
Definition: load_store.h:41
SharedLoadStreamA_ SharedLoadStreamA
The iterator for A to load from shared memory.
Definition: gemm_traits.h:384
Definition: convert.h:33
Definition: gemm_shared_tile.h:128
GlobalLoadStreamB_ GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: gemm_traits.h:377
static int const kThreads
The number of threads.
Definition: gemm_config.h:103
TileStoreIterator< typename GemmTileTraitsHelperA_::SharedStoreTileTraits, typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
The iterator to store A to shared memory.
Definition: gemm_traits.h:586
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Epilogue::ScalarD ScalarD
Definition: gemm_traits.h:394
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, typename Epilogue::Scalar alpha, ScalarA const *d_a, Index lda, ScalarB const *d_b, Index ldb, typename Epilogue::Scalar beta, ScalarC const *d_c, Index ldc, ScalarD *d_d, Index ldd)
Helper to construct a GEMM params using a BLAS-like API.
Definition: gemm_traits.h:474
The storage in shared memory.
Definition: gemm_traits.h:551
SharedLoadStream< SharedLoadIteratorB > SharedLoadStreamB
The stream to load B from shared memory.
Definition: gemm_traits.h:627
Definition: gemm_global_tile.h:70
Defines a structure containing shared storage for each pair.
Definition: gemm_stream_pair.h:91
GlobalLoadStream< GemmOperand::kB, GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: gemm_traits.h:610
GemmConfig_::ScalarB Scalar
The input scalar.
Definition: gemm_traits.h:201
GemmSharedStoreTileAbTraits< MultiplyAddScalar, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/GemmConfig_::InstructionShape::kD, GemmConfig_::OutputTile::kH *GemmConfig_::InstructionShape::kD >, typename GlobalTileTraits::Threads, GemmConfig_::kScalarsPerStsB > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for B^T.
Definition: gemm_traits.h:304
Definition: gemm_coord.h:43
GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ > This_
This traits.
Definition: gemm_traits.h:359
SharedLoadStreamB_ SharedLoadStreamB
The iterator for B to load from shared memory.
Definition: gemm_traits.h:386
Defines structures and helpers to launch CUDA kernels within CUTLASS.
GlobalLoadStreamPair< GlobalLoadStreamA, GlobalLoadStreamB, GemmConfig::kResidueInProlog > GlobalLoadStream
Assemble the global load streams for A/B.
Definition: gemm_traits.h:407
GemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^N.
Definition: gemm_traits.h:219
Definition: convert.h:69
ThreadblockTileStorage threadblock_tile
Stores the threadblock tile.
Definition: gemm_traits.h:541
SharedLoadStream< SharedLoadIteratorA > SharedLoadStreamA
The stream to load A from shared memory.
Definition: gemm_traits.h:619
Definition: gemm_shared_tile.h:38
GemmSharedLoadTileATraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsA, 0 > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for A^N.
Definition: gemm_traits.h:114
Epilogue_ Epilogue
The epilogue.
Definition: gemm_traits.h:391
GlobalLoadStreamA_::Scalar ScalarA
The scalar for A.
Definition: gemm_traits.h:374
Definition: tile_iterator.h:65
GemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kColumnMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^N.
Definition: gemm_traits.h:80
GlobalLoadStream::Params global_to_shared_stream
Parameters object for the global load stream.
Definition: gemm_traits.h:422
GemmConfig_::ScalarA Scalar
The input scalar.
Definition: gemm_traits.h:62
Definition: gemm_shared_tile.h:200
Definition: gemm_global_tile.h:163
Epilogue::ScalarC ScalarC
The scalars in the epilogue.
Definition: gemm_traits.h:393
GemmConfig::MultiplyAdd MultiplyAdd
The multiply-add functor.
Definition: gemm_traits.h:389
static CUTLASS_DEVICE void shared_load_fence(bool in_loop)
The memory fence for shared loads.
Definition: gemm_traits.h:559
GemmConfig_ GemmConfig
The configuration.
Definition: gemm_traits.h:365
Definition: gemm_global_stream.h:52
Definition: gemm_traits.h:191
Definition: clear_accumulators.h:38
Parameters object constructable on the host.
Definition: gemm_traits.h:416
Collect the global load streams for multiplicands.
Definition: gemm_stream_pair.h:173
Copy< typename GlobalLoadIteratorB::Fragment > GlobalTransformerB
The data converter for B before storing to shared memory.
Definition: gemm_traits.h:598
GemmConfig_::ScalarB Scalar
The input scalar.
Definition: gemm_traits.h:272
StreamB::Params stream_b
Parameters object for StreamB.
Definition: gemm_stream_pair.h:67
Definition: gemm.h:92
Defines data layouts of various matrix formats usable by TensorRef and other classes.
Definition: matrix_traits.h:156
GemmGlobalIteratorAb< typename GemmTileTraitsHelperB_::GlobalTileTraits, Index_ > GlobalLoadIteratorB
The global iterator to load B from global memory.
Definition: gemm_traits.h:596
static bool const kResidueInProlog
If true, residue is computed in the prologue.
Definition: gemm_config.h:136
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:399
Collect the global load streams for multiplicands.
Definition: gemm_stream_pair.h:50
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_config.h:90
Definition: matrix_traits.h:159
Defines a fragment based on a Shape<> template.
Structure containing the basic launch configuration of a CUDA kernel.
Definition: kernel_launch.h:38
ClearAccumulators_ ClearAccumulators
Clear the accumulators.
Definition: gemm_traits.h:401
Definition: gemm_shared_stream.h:45
GemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^T.
Definition: gemm_traits.h:143
Parameters object.
Definition: gemm_stream_pair.h:62
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
GemmCoord problem_size
GEMM problem size.
Definition: gemm_traits.h:419
Implements a software-pipelined efficient GEMM.
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, typename Epilogue::Scalar alpha, ScalarA const *d_a, Index lda, long long int batch_stride_A, ScalarB const *d_b, Index ldb, long long int batch_stride_B, typename Epilogue::Scalar beta, ScalarC const *d_c, Index ldc, long long int batch_stride_C, ScalarD *d_d, Index ldd, long long int batch_stride_D, Index batch_count)
Helper to construct a batched GEMM params.
Definition: gemm_traits.h:501
Defines abstractions for efficiently clearing accumulator tiles.
Definition: tensor_ref.h:131
SharedStreamPair< SharedLoadStreamA, SharedLoadStreamB > SharedStream
Assemble the shared load streams for A/B.
Definition: gemm_traits.h:413
static CUTLASS_DEVICE void shared_store_fence(bool in_loop)
The memory fence for shared stores.
Definition: gemm_traits.h:567
GemmConfig_::ScalarA Scalar
The input scalar.
Definition: gemm_traits.h:125
Manages a pair of tile allocations as if they are one allocation.
Definition: tile_allocation.h:100
Definition: gemm_traits.h:52
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: gemm_traits.h:432
Definition: matrix_traits.h:357
Definition: threadblock_swizzle.h:65
GemmSharedStoreTileAbTraits< MultiplyAddScalar, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/GemmConfig_::InstructionShape::kD, GemmConfig_::OutputTile::kW *GemmConfig_::InstructionShape::kD >, typename GlobalTileTraits::Threads, GemmConfig_::kScalarsPerStsA > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for A^N.
Definition: gemm_traits.h:94
GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:274
GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:203
GlobalLoadStreamB_::Scalar ScalarB
The scalar for B.
Definition: gemm_traits.h:381
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Defines properties of GEMM computation that impose some constraints on caller.
Definition: gemm_traits.h:349
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
cutlass::gemm::Gemm< This_ > KernelClass
The struct that consumes this Traits.
Definition: gemm_traits.h:362
SharedStream::Params shared_stream
Parameters object for the shared load stream.
Definition: gemm_traits.h:425
ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:88
BlockSwizzle_ BlockSwizzle
The block swizzle to reorganize the grid.
Definition: gemm_traits.h:397
TileLoadIterator< typename GemmTileTraitsHelperA_::SharedLoadTileTraits, typename GemmTileTraitsHelperA_::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
The iterator to load A from shared memory.
Definition: gemm_traits.h:617
Definition: matrix_traits.h:159
TileLoadIterator< typename GemmTileTraitsHelperB_::SharedLoadTileTraits, typename GemmTileTraitsHelperB_::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
The iterator to load B from shared memory.
Definition: gemm_traits.h:625
GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage
Memory needed to store the threadblock-scoped GEMM tile.
Definition: gemm_traits.h:410
dim3 block
CUDA threadblock dimensions.
Definition: kernel_launch.h:44
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue.h:83
Index_ Index
The index.
Definition: gemm_traits.h:399
GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:64
TileStoreIterator< typename GemmTileTraitsHelperB_::SharedStoreTileTraits, typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
The iterator to store B to shared memory.
Definition: gemm_traits.h:604
Epilogue::Params epilogue
The params for the epilogue.
Definition: gemm_traits.h:428
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Defines a pair of GEMM tile streams.
The shared storage.
Definition: clear_accumulators.h:40
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
MainLoopSharedStorage main_loop
Definition: gemm_traits.h:553
static MatrixLayout::Kind const kLayoutA
The layout of A.
Definition: gemm_traits.h:372
dim3 grid
CUDA grid dimensions.
Definition: kernel_launch.h:41
Definition: matrix_traits.h:357
GlobalLoadStream::SharedStorage global_to_shared_stream
Storage for GEMM global stream.
Definition: gemm_traits.h:544
Parameters object passed to load iterators.
Definition: gemm_stream_pair.h:185
Defines functors for mapping blockIdx to partitions of the GEMM computation.
Definition: gemm_traits.h:575
Implements a software-pipelined efficient GEMM.
GemmGlobalIteratorAb< typename GemmTileTraitsHelperA_::GlobalTileTraits, Index_ > GlobalLoadIteratorA
The global iterator to load A from global memory.
Definition: gemm_traits.h:578
GemmConfig::OutputTile OutputTile
The output tile.
Definition: gemm_traits.h:367
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Copy< typename GlobalLoadIteratorA::Fragment > GlobalTransformerA
The data converter for A before storing to shared memory.
Definition: gemm_traits.h:580
GemmSharedLoadTileBTraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsB, 0 > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for B^T.
Definition: gemm_traits.h:324
ClearAccumulators::SharedStorage clear
Storage for clearing accumulators.
Definition: gemm_traits.h:547
StreamA::Params stream_a
Parameters object for StreamA.
Definition: gemm_stream_pair.h:64
GemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kRowMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^T.
Definition: gemm_traits.h:290
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Defines conversion operations among Fragments of different base type.
Definition: gemm_traits.h:650
OutputTile_ OutputTile
The tile.
Definition: gemm_config.h:88
static MatrixLayout::Kind const kLayoutB
The layout of B.
Definition: gemm_traits.h:379
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:836
GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:127