Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/convert.h>
36 #include <cutlass/matrix_traits.h>
37 #include <cutlass/reshape_tile.h>
38 #include <cutlass/tile_iterator.h>
39 
40 namespace cutlass {
41 namespace gemm {
42 
44 
45 template <
47  typename ScalarA_,
49  typename ScalarB_,
51  typename ScalarC_,
53  typename ScalarD_,
55  typename OutputTile_,
57  typename MultiplyAdd_,
59  int kScalarsPerLdgA_,
61  int kScalarsPerStsA_,
63  int kScalarsPerLdsA_,
65  int kScalarsPerLdgB_,
67  int kScalarsPerStsB_,
69  int kScalarsPerLdsB_,
71  int kScalarsPerLdgCAndStgD_,
73  int kScalarsPerStsD_,
75  int kScalarsPerLdsD_,
77  int kStages_>
78 
79 struct GemmConfig {
80  //
82  typedef ScalarA_ ScalarA;
84  typedef ScalarB_ ScalarB;
86  typedef ScalarC_ ScalarC;
88  typedef ScalarD_ ScalarD;
89 
91  typedef OutputTile_ OutputTile;
93  typedef MultiplyAdd_ MultiplyAdd;
100 
104  static int const kWarpSize = cutlass::kWarpSize;
107 
109  static int const kScalarsPerLdgA = kScalarsPerLdgA_;
110  static int const kScalarsPerStsA = kScalarsPerStsA_;
111  static int const kScalarsPerLdsA = kScalarsPerLdsA_;
112 
114  static int const kScalarsPerLdgB = kScalarsPerLdgB_;
115  static int const kScalarsPerStsB = kScalarsPerStsB_;
116  static int const kScalarsPerLdsB = kScalarsPerLdsB_;
117 
119  static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;
120 
122  static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
123  static int const kScalarsPerStsD = kScalarsPerStsD_;
124  static int const kScalarsPerLdsD = kScalarsPerLdsD_;
125 
127  static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
128  static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;
129 
131  static int const kStages = kStages_;
132 };
133 
135 
136 template <enum MatrixLayout::Kind, typename GemmConfig_>
138 
140 
141 template <typename GemmConfig_>
142 struct GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
145 
147  typedef typename GemmConfig_::ScalarA Scalar;
149  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
150 
152  typedef GemmGlobalTileTraits<
153  // That's A.
155  // A is column-major.
157  // The pointer is float const.
158  Scalar const,
159  // The tile has size KxM in GEMM's terminology.
161  // The threads are distributed as warps x 32 (the traits may reorganize).
163  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
164  GemmConfig_::kScalarsPerLdgA>
166 
169  // The pointer is float.
171  // The tile has size KxM in GEMM's terminology.
172  Shape<GemmConfig_::kStages,
173  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
174  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
175  // The threads are distributed as warps x 32 (the traits may reorganize).
176  typename GlobalTileTraits::Threads,
177  // The number of scalars per STS (STS.32 or STS.128, etc).
178  GemmConfig_::kScalarsPerStsA>
180 
183  // The pointer is float const.
184  MultiplyAddScalar const,
185  // The output tile size.
186  typename GemmConfig_::OutputTile,
187  // The number of warps.
188  typename GemmConfig_::Warps,
189  // The number of threads per warp.
190  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
191  // The shape of the FMA instruction.
192  typename GemmConfig_::InstructionShape,
193  // The number of stages.
194  GemmConfig_::kStages,
195  // The number of scalars per LDS.
196  GemmConfig_::kScalarsPerLdsA,
197  // The skew.
198  0>
200 };
201 
203 
204 template <typename GemmConfig_>
205 struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
208 
210  typedef typename GemmConfig_::ScalarA Scalar;
212  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
213 
215  typedef GemmGlobalTileTraits<
216  // That's A.
218  // A is row-major.
220  // The pointer is float const.
221  Scalar const,
222  // The tile has size MxK in GEMM's terminology.
224  // The threads are distributed as (threads / K) x K (the traits may reorganize).
225  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
226  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
227  GemmConfig_::kScalarsPerLdgA>
229 
231  static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
234  // The pointer is float.
236  // The tile has size KxM in GEMM's terminology.
237  Shape<GemmConfig_::kStages,
238  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
239  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
240  // The threads are distributed as (threads / K) x K (the traits may reorganize).
241  typename GlobalTileTraits::Threads,
242  // The number of scalars per STS.
243  GemmConfig_::kScalarsPerStsA,
244  // The skew to avoid bank conflicts added in the tile W dimension.
245  128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
246  GlobalTileTraits::Threads::kW * kScalarsIn4B>
248 
251  // The pointer is float const.
252  MultiplyAddScalar const,
253  // The output tile size.
254  typename GemmConfig_::OutputTile,
255  // The number of warps.
256  typename GemmConfig_::Warps,
257  // The number of threads per warp.
258  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
259  // The shape of the FMA instruction.
260  typename GemmConfig_::InstructionShape,
261  // The number of stages.
262  GemmConfig_::kStages,
263  // The number of scalars per LDS.
264  GemmConfig_::kScalarsPerLdsA,
265  // The skew.
266  SharedStoreTileTraits::kSkew>
268 };
269 
271 
272 template <enum MatrixLayout::Kind, typename GemmConfig_>
274 
276 
277 template <typename GemmConfig_>
278 struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
281 
283  typedef typename GemmConfig_::ScalarB Scalar;
285  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
286 
288  typedef GemmGlobalTileTraits<
289  // That's B.
291  // B is column-major.
293  // The pointer is float const.
294  Scalar const,
295  // The tile has size MxK in GEMM's terminology.
297  // The threads are distributed as (threads / K) x K (the traits may reorganize).
298  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
299  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
300  GemmConfig_::kScalarsPerLdgB>
302 
304  static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
307  // The pointer is float.
309  // The tile has size KxN in GEMM's terminology.
310  Shape<GemmConfig_::kStages,
311  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
312  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
313  // The threads are distributed as (threads / K) x K (the traits may reorganize).
314  typename GlobalTileTraits::Threads,
315  // The number of scalars per STS.
316  GemmConfig_::kScalarsPerStsB,
317  // The skew to avoid bank conflicts added in the tile W dimension.
318  128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
319  GlobalTileTraits::Threads::kW * kScalarsIn4B>
321 
324  // The pointer is float const.
325  MultiplyAddScalar const,
326  // The output tile size.
327  typename GemmConfig_::OutputTile,
328  // The number of warps.
329  typename GemmConfig_::Warps,
330  // The number of threads per warp.
331  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
332  // The shape of the FMA instruction.
333  typename GemmConfig_::InstructionShape,
334  // The number of stages.
335  GemmConfig_::kStages,
336  // The number of scalars per LDS.
337  GemmConfig_::kScalarsPerLdsB,
338  // The skew.
339  SharedStoreTileTraits::kSkew>
341 };
342 
344 
345 template <typename GemmConfig_>
346 struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
349 
351  typedef typename GemmConfig_::ScalarB Scalar;
353  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
354 
356  typedef GemmGlobalTileTraits<
357  // That's B.
359  // B is row-major.
361  // The pointer is float const.
362  Scalar const,
363  // The tile has size KxN in GEMM's terminology.
365  // The threads are distributed as warps x 32 (the traits may reorganize).
367  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
368  GemmConfig_::kScalarsPerLdgB>
370 
373  // The pointer is float.
375  // The tile has size KxN in GEMM's terminology.
376  Shape<GemmConfig_::kStages,
377  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
378  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
379  // The threads are distributed as warps x 32 (the traits may reorganize).
380  typename GlobalTileTraits::Threads,
381  // The number of scalars per STS (STS.32 or STS.128, etc).
382  GemmConfig_::kScalarsPerStsB>
384 
387  // The pointer is float const.
388  MultiplyAddScalar const,
389  // The output tile size.
390  typename GemmConfig_::OutputTile,
391  // The number of warps.
392  typename GemmConfig_::Warps,
393  // The number of threads per warp.
394  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
395  // The shape of the FMA instruction.
396  typename GemmConfig_::InstructionShape,
397  // The number of stages.
398  GemmConfig_::kStages,
399  // The number of scalars per LDS.
400  GemmConfig_::kScalarsPerLdsB,
401  // The skew.
402  0>
404 };
405 
407 
408 template <
410  typename GemmConfig_,
412  typename GlobalLoadStreamA_,
414  typename GlobalLoadStreamB_,
416  typename SharedLoadStreamA_,
418  typename SharedLoadStreamB_,
420  typename Epilogue_,
422  typename BlockSwizzle_ = IdentityBlockSwizzle,
424  typename Index_ = int,
427 
428 struct GemmTraits {
430  typedef GemmConfig_ GemmConfig;
433 
435  typedef GlobalLoadStreamA_ GlobalLoadStreamA;
437  static MatrixLayout::Kind const kLayoutA = GlobalLoadStreamA::kLayout;
439  typedef typename GlobalLoadStreamA_::Scalar ScalarA;
440 
442  typedef GlobalLoadStreamB_ GlobalLoadStreamB;
444  static MatrixLayout::Kind const kLayoutB = GlobalLoadStreamB::kLayout;
446  typedef typename GlobalLoadStreamB_::Scalar ScalarB;
447 
449  typedef SharedLoadStreamA_ SharedLoadStreamA;
451  typedef SharedLoadStreamB_ SharedLoadStreamB;
452 
454  typedef typename GlobalLoadStreamA::SharedStoreStorage SharedStoreStorageA;
455  // Btw, make sure we did not messed up with the size of the storage.
456  static_assert(sizeof(SharedStoreStorageA) == sizeof(typename SharedLoadStreamA::SharedStorage),
457  "");
458 
460  typedef typename GlobalLoadStreamB::SharedStoreStorage SharedStoreStorageB;
461  // Btw, make sure we did not messed up with the size of the storage.
462  static_assert(sizeof(SharedStoreStorageB) == sizeof(typename SharedLoadStreamB::SharedStorage),
463  "");
464 
466  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
468  typedef Epilogue_ Epilogue;
470  typedef typename Epilogue::ScalarC ScalarC;
471  typedef typename Epilogue::ScalarD ScalarD;
472 
474  typedef BlockSwizzle_ BlockSwizzle;
476  typedef Index_ Index;
478  typedef ClearAccumulators_ ClearAccumulators;
479 
481  struct Params {
483  Index m, n, k;
485  typename GlobalLoadStreamA::Params global_stream_a;
487  typename GlobalLoadStreamB::Params global_stream_b;
489  typename SharedLoadStreamA::Params shared_stream_a;
491  typename SharedLoadStreamB::Params shared_stream_b;
493  typename Epilogue::Params epilogue;
494 
496  template <typename GemmDesc_>
497  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
498  // Set the problem size.
499  this->m = desc.m;
500  this->n = desc.n;
501  this->k = desc.k;
502 
503  // Initialize the iterator for A.
504  int error_code =
505  global_stream_a.initialize(reinterpret_cast<ScalarA const*>(desc.d_a), desc.lda);
506 
507  if (error_code) {
508  return error_code;
509  }
510 
511  // Initialize the iterator for B.
512  error_code = global_stream_b.initialize(reinterpret_cast<ScalarB const*>(desc.d_b), desc.ldb);
513 
514  if (error_code) {
515  return error_code;
516  }
517 
518  // The epilogue.
519  return epilogue.initialize(desc);
520  }
521  };
522 
523  // The storage for A.
524  template <typename GlobalLoadStream_, typename SharedLoadStream_>
526  // The storage needed by the global stream.
527  typename GlobalLoadStream_::SharedStorage global;
528  // The storage needed by the shared stream.
529  typename SharedLoadStream_::SharedStorage shared;
530  };
531 
532  // The storage for the main loop + prologue.
534  // The storage to shuffle the A matrix in shared memory.
536  // The storage to shuffle the B matrix in shared memory.
538  // The storage to clear the accumulators if needed.
540  };
541 
544  // The storage for the main loop.
546  // The storage for the epilogue.
547  typename Epilogue::SharedStorage epilogue;
548  };
549 
553  CUTLASS_DEVICE GlobalLoadStream(Params const& params,
554  SharedStorage& shared_storage,
555  dim3 const& block)
556  : stream_a(params.global_stream_a,
557  shared_storage.main_loop.stream_a.global,
558  cutlass::make_Coord(0, params.k, params.m),
559  cutlass::make_Coord(0, 0, block.x)),
560  stream_b(params.global_stream_b,
561  shared_storage.main_loop.stream_b.global,
562  cutlass::make_Coord(0, params.k, params.n),
563  make_Coord(0, 0, block.y)) {}
564 
566  CUTLASS_DEVICE void copy() {
567  stream_a.copy();
568  stream_b.copy();
569  }
570 
572  CUTLASS_DEVICE void commit() {
573  stream_a.commit();
574  stream_b.commit();
575  }
576 
578  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
579  stream_a.residue(k, skip_clear);
580  stream_b.residue(k, skip_clear);
581  }
582 
587  };
588 
592  CUTLASS_DEVICE SharedLoadStream(Params const& params, SharedStorage& shared_storage) {
593  stream_a.initialize(params.shared_stream_a, shared_storage.main_loop.stream_a.shared);
594  stream_b.initialize(params.shared_stream_b, shared_storage.main_loop.stream_b.shared);
595  }
596 
598  CUTLASS_DEVICE void copy(int step) {
599  stream_a.copy(step, fetched_a[step % 2]);
600  stream_b.copy(step, fetched_b[step % 2]);
601  }
602 
604  CUTLASS_DEVICE void commit(int step) {
605  stream_a.commit(fetched_a[step % 2], transformed_a[step % 2]);
606  stream_b.commit(fetched_b[step % 2], transformed_b[step % 2]);
607  }
608 
610  CUTLASS_DEVICE typename SharedLoadStreamA::Fragment const& fragment_a(int step) const {
611  return transformed_a[step % 2];
612  }
613 
615  CUTLASS_DEVICE typename SharedLoadStreamB::Fragment const& fragment_b(int step) const {
616  return transformed_b[step % 2];
617  }
618 
620  CUTLASS_DEVICE void inc_stage() {
621  stream_a.inc_stage();
622  stream_b.inc_stage();
623  }
624 
628  typename SharedLoadStreamA::FetchedFragment fetched_a[2];
630  typename SharedLoadStreamA::TransformedFragment transformed_a[2];
634  typename SharedLoadStreamB::FetchedFragment fetched_b[2];
636  typename SharedLoadStreamB::TransformedFragment transformed_b[2];
637  };
638 
640  static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
641  if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
642  SharedLoadStreamB::Iterator::kRequiresLoadFence) {
643  __syncthreads();
644  }
645  }
646 
648  static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); }
649 };
650 
652 
653 template <typename GemmTileTraitsHelperA_, typename GemmTileTraitsHelperB_, typename Index_>
661  typedef TileStoreIterator<typename GemmTileTraitsHelperA_::SharedStoreTileTraits,
662  typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar,
669 
676  typedef TileStoreIterator<typename GemmTileTraitsHelperB_::SharedStoreTileTraits,
677  typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar,
684 
686  typedef TileLoadIterator<typename GemmTileTraitsHelperA_::SharedLoadTileTraits,
687  typename GemmTileTraitsHelperA_::Scalar,
694  typedef TileLoadIterator<typename GemmTileTraitsHelperB_::SharedLoadTileTraits,
695  typename GemmTileTraitsHelperB_::Scalar,
701 };
702 
704 
705 template <
707  MatrixLayout::Kind kLayoutA_,
709  MatrixLayout::Kind kLayoutB_,
711  typename GemmConfig_,
713  typename Epilogue_,
715  typename Index_ = int,
716  // The configuration for the A matrix.
717  typename GemmTileTraitsHelperA_ = GemmTileTraitsHelperA<kLayoutA_, GemmConfig_>,
718  // The configuration for the B matrix.
719  typename GemmTileTraitsHelperB_ = GemmTileTraitsHelperB<kLayoutB_, GemmConfig_>,
720  // The helper class to create the streams and iterators.
721  typename Helper_ =
724  // The config.
725  GemmConfig_,
726  // The stream to load A from global memory to shared memory.
727  typename Helper_::GlobalLoadStreamA,
728  // The stream to load B from global memory to shared memory.
729  typename Helper_::GlobalLoadStreamB,
730  // The stream to load A from shared memory.
731  typename Helper_::SharedLoadStreamA,
732  // The stream to load B from shared memory.
733  typename Helper_::SharedLoadStreamB,
734  // The epilogue.
735  Epilogue_,
736  // The block swizzle to reorganize the grid.
737  IdentityBlockSwizzle,
738  // The index.
739  Index_,
740  // The tool used to clear accumulators.
741  ClearAccumulators<typename GemmConfig_::Accumulators::Element> > {
742 };
743 
745 
746 } // namespace gemm
747 } // namespace cutlass
Index n
Definition: gemm_traits.h:483
static int const kWarpSize
The default warp size (32 threads per warp).
Definition: gemm_traits.h:104
Epilogue::SharedStorage epilogue
Definition: gemm_traits.h:547
static int const kScalarsPerStsA
Definition: gemm_traits.h:110
GemmSharedLoadTileBTraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsB, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for B^N.
Definition: gemm_traits.h:340
ScalarA_ ScalarA
The scalar for A.
Definition: gemm_traits.h:82
GlobalLoadStreamA_ GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: gemm_traits.h:435
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue.h:98
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_traits.h:93
static int const kAccumulatorsPerLdsA
The number of accumulators that are going to be fed from one LDS A/B.
Definition: gemm_traits.h:127
Definition: load_store.h:42
static int const kScalarsPerLdsA
Definition: gemm_traits.h:111
SharedLoadStreamA_ SharedLoadStreamA
The stream to load A from shared memory.
Definition: gemm_traits.h:449
MultiplyAdd::InstructionShape InstructionShape
The shape of the instruction.
Definition: gemm_traits.h:95
Definition: convert.h:33
SharedLoadStreamA::Params shared_stream_a
The params for the A stream from shared memory.
Definition: gemm_traits.h:489
Definition: gemm_shared_tile.h:129
GlobalLoadStreamB_ GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: gemm_traits.h:442
CUTLASS_DEVICE void inc_stage()
Increment the stage.
Definition: gemm_traits.h:620
TileStoreIterator< typename GemmTileTraitsHelperA_::SharedStoreTileTraits, typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
The iterator to store A to shared memory.
Definition: gemm_traits.h:665
static int const kScalarsPerLdsB
Definition: gemm_traits.h:116
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Epilogue::ScalarD ScalarD
Definition: gemm_traits.h:471
The storage in shared memory.
Definition: gemm_traits.h:543
SharedLoadStream< SharedLoadIteratorB > SharedLoadStreamB
The stream to load B from shared memory.
Definition: gemm_traits.h:700
Index k
Definition: gemm_traits.h:483
Definition: gemm_global_tile.h:70
SharedLoadStreamA::FetchedFragment fetched_a[2]
The fragments to fetch A.
Definition: gemm_traits.h:628
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
GemmConfig_::ScalarB Scalar
The input scalar.
Definition: gemm_traits.h:283
GemmSharedStoreTileAbTraits< MultiplyAddScalar, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/GemmConfig_::InstructionShape::kD, GemmConfig_::OutputTile::kH *GemmConfig_::InstructionShape::kD >, typename GlobalTileTraits::Threads, GemmConfig_::kScalarsPerStsB > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for B^T.
Definition: gemm_traits.h:383
SharedLoadStreamB_ SharedLoadStreamB
The stream to load B from shared memory.
Definition: gemm_traits.h:451
static int const kScalarsPerStgD
The number of scalars per STS/LDS/STG for D.
Definition: gemm_traits.h:122
CUTLASS_DEVICE void copy(int step)
Trigger the copies from shared memory to registers.
Definition: gemm_traits.h:598
GemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^N.
Definition: gemm_traits.h:301
Definition: convert.h:69
A template defining Fragment Concept.
Definition: fragment.h:99
SharedLoadStreamA stream_a
The stream for A.
Definition: gemm_traits.h:626
SharedLoadStream< SharedLoadIteratorA > SharedLoadStreamA
The stream to load A from shared memory.
Definition: gemm_traits.h:692
Definition: gemm_shared_tile.h:38
ScalarC_ ScalarC
The scalar for C.
Definition: gemm_traits.h:86
CUTLASS_DEVICE void copy()
Trigger the copies of A and B from global memory.
Definition: gemm_traits.h:566
GemmSharedLoadTileATraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsA, 0 > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for A^N.
Definition: gemm_traits.h:199
Epilogue_ Epilogue
The epilogue.
Definition: gemm_traits.h:468
GlobalLoadStreamA_::Scalar ScalarA
The scalar for A.
Definition: gemm_traits.h:439
Definition: tile_iterator.h:62
GemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kColumnMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^N.
Definition: gemm_traits.h:165
ShapeDiv< OutputTile, AccumulatorsPerWarp >::Shape Warps
The number of warps.
Definition: gemm_traits.h:102
GemmConfig_::ScalarA Scalar
The input scalar.
Definition: gemm_traits.h:147
Definition: gemm_shared_tile.h:198
GlobalLoadStreamB::SharedStoreStorage SharedStoreStorageB
The shared storage for B.
Definition: gemm_traits.h:457
Definition: gemm_global_tile.h:159
Epilogue::ScalarC ScalarC
The scalars in the epilogue.
Definition: gemm_traits.h:470
GlobalLoadStream< GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: gemm_traits.h:683
SharedLoadStreamB stream_b
The stream for B.
Definition: gemm_traits.h:632
Assemble the shared load streams for A/B.
Definition: gemm_traits.h:590
GlobalLoadStreamB stream_b
The stream for B.
Definition: gemm_traits.h:586
GemmConfig::MultiplyAdd MultiplyAdd
The multiply-add functor.
Definition: gemm_traits.h:463
static CUTLASS_DEVICE void shared_load_fence(bool in_loop)
The memory fence for shared loads.
Definition: gemm_traits.h:640
GemmConfig_ GemmConfig
The configuration.
Definition: gemm_traits.h:430
Definition: gemm_global_stream.h:161
SharedLoadStreamB::TransformedFragment transformed_b[2]
The fragments to transform B.
Definition: gemm_traits.h:636
Definition: gemm_traits.h:273
GlobalLoadStreamA stream_a
The stream for A.
Definition: gemm_traits.h:584
GemmSharedLoadTileATraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsA, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for A^T.
Definition: gemm_traits.h:267
Definition: clear_accumulators.h:38
StreamSharedStorage< GlobalLoadStreamB, SharedLoadStreamB > stream_b
Definition: gemm_traits.h:537
The params.
Definition: gemm_traits.h:481
static int const kScalarsPerLdgA
The number of scalars per LDG/STS/LDS for A.
Definition: gemm_traits.h:109
CUTLASS_DEVICE SharedLoadStreamB::Fragment const & fragment_b(int step) const
The fragment B.
Definition: gemm_traits.h:615
Copy< typename GlobalLoadIteratorB::Fragment > GlobalTransformerB
The data converter for B before storing to shared memory.
Definition: gemm_traits.h:674
GemmConfig_::ScalarB Scalar
The input scalar.
Definition: gemm_traits.h:351
Describes layouts of matrices.
Definition: matrix_traits.h:35
GemmGlobalIteratorAb< typename GemmTileTraitsHelperB_::GlobalTileTraits, Index_ > GlobalLoadIteratorB
The global iterator to load B from global memory.
Definition: gemm_traits.h:672
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:302
Definition: matrix_traits.h:36
CUTLASS_DEVICE void residue(Index k, bool skip_clear=false)
Execute the residue code.
Definition: gemm_traits.h:578
MultiplyAdd::Accumulators Accumulators
The accumulators.
Definition: gemm_traits.h:99
ClearAccumulators_ ClearAccumulators
Clear the accumulators.
Definition: gemm_traits.h:478
Definition: gemm_shared_stream.h:44
GemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^T.
Definition: gemm_traits.h:228
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
static int const kScalarsPerStsB
Definition: gemm_traits.h:115
Defines abstractions for efficiently clearing accumulator tiles.
Definition: gemm_traits.h:79
Assemble the global load streams for A/B.
Definition: gemm_traits.h:551
static int const kScalarsPerStsD
Definition: gemm_traits.h:123
static CUTLASS_DEVICE void shared_store_fence(bool in_loop)
The memory fence for shared stores.
Definition: gemm_traits.h:648
GemmConfig_::ScalarA Scalar
The input scalar.
Definition: gemm_traits.h:210
Definition: gemm_traits.h:137
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: gemm_traits.h:497
GlobalLoadStream_::SharedStorage global
Definition: gemm_traits.h:527
Definition: matrix_traits.h:43
Definition: identity_block_swizzle.h:37
GemmSharedStoreTileAbTraits< MultiplyAddScalar, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/GemmConfig_::InstructionShape::kD, GemmConfig_::OutputTile::kW *GemmConfig_::InstructionShape::kD >, typename GlobalTileTraits::Threads, GemmConfig_::kScalarsPerStsA > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for A^N.
Definition: gemm_traits.h:179
ScalarB_ ScalarB
The scalar for B.
Definition: gemm_traits.h:84
GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:353
GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:285
GlobalLoadStreamB_::Scalar ScalarB
The scalar for B.
Definition: gemm_traits.h:446
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
GlobalLoadStreamA::SharedStoreStorage SharedStoreStorageA
The shared storage for A.
Definition: gemm_traits.h:454
GlobalLoadStream< GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: gemm_traits.h:668
#define static_assert(__e, __m)
Definition: platform.h:145
Definition: gemm_traits.h:428
MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp
The number of accumulators per warp.
Definition: gemm_traits.h:97
SharedLoadStreamA::TransformedFragment transformed_a[2]
The fragments to transform A.
Definition: gemm_traits.h:630
SharedLoadStream_::SharedStorage shared
Definition: gemm_traits.h:529
GlobalLoadStreamB::Params global_stream_b
The params for the B stream.
Definition: gemm_traits.h:487
SharedLoadStreamB::FetchedFragment fetched_b[2]
The fragments to fetch B.
Definition: gemm_traits.h:634
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
static int const kScalarsPerLdgC
The number of scalars per LDG for C.
Definition: gemm_traits.h:119
ScalarD_ ScalarD
The scalar for D.
Definition: gemm_traits.h:88
static int const kThreads
The number of threads.
Definition: gemm_traits.h:106
Defines functors for mapping blockIdx to partitions of the GEMM computation.
Index m
The dimensions of the GEMM.
Definition: gemm_traits.h:483
BlockSwizzle_ BlockSwizzle
The block swizzle to reorganize the grid.
Definition: gemm_traits.h:474
TileLoadIterator< typename GemmTileTraitsHelperA_::SharedLoadTileTraits, typename GemmTileTraitsHelperA_::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
The iterator to load A from shared memory.
Definition: gemm_traits.h:690
Definition: matrix_traits.h:36
TileLoadIterator< typename GemmTileTraitsHelperB_::SharedLoadTileTraits, typename GemmTileTraitsHelperB_::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
The iterator to load B from shared memory.
Definition: gemm_traits.h:698
CUTLASS_DEVICE SharedLoadStream(Params const &params, SharedStorage &shared_storage)
Ctor.
Definition: gemm_traits.h:592
CUTLASS_DEVICE GlobalLoadStream(Params const &params, SharedStorage &shared_storage, dim3 const &block)
Ctor.
Definition: gemm_traits.h:553
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue.h:96
Index_ Index
The index.
Definition: gemm_traits.h:476
GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:149
TileStoreIterator< typename GemmTileTraitsHelperB_::SharedStoreTileTraits, typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
The iterator to store B to shared memory.
Definition: gemm_traits.h:680
Epilogue::Params epilogue
The params for the epilogue.
Definition: gemm_traits.h:493
Kind
Definition: matrix_traits.h:36
GlobalLoadStreamA::Params global_stream_a
The params for the A stream.
Definition: gemm_traits.h:485
The shared storage.
Definition: clear_accumulators.h:40
CUTLASS_DEVICE void commit(int step)
Commit the data.
Definition: gemm_traits.h:604
static int const kScalarsPerLdsD
Definition: gemm_traits.h:124
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
MainLoopSharedStorage main_loop
Definition: gemm_traits.h:545
static MatrixLayout::Kind const kLayoutA
The layout of A.
Definition: gemm_traits.h:437
OutputTile_ OutputTile
The tile.
Definition: gemm_traits.h:91
static int const kScalarsPerLdgB
The number of scalars per LDG/STS/LDS for B.
Definition: gemm_traits.h:114
Definition: matrix_traits.h:43
Definition: gemm_traits.h:654
ReshapeThreads< Tile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:87
GemmGlobalIteratorAb< typename GemmTileTraitsHelperA_::GlobalTileTraits, Index_ > GlobalLoadIteratorA
The global iterator to load A from global memory.
Definition: gemm_traits.h:657
GemmConfig::OutputTile OutputTile
The output tile.
Definition: gemm_traits.h:432
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Copy< typename GlobalLoadIteratorA::Fragment > GlobalTransformerA
The data converter for A before storing to shared memory.
Definition: gemm_traits.h:659
CUTLASS_DEVICE void commit()
Commit the data.
Definition: gemm_traits.h:572
GemmSharedLoadTileBTraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsB, 0 > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for B^T.
Definition: gemm_traits.h:403
ClearAccumulators::SharedStorage clear
Definition: gemm_traits.h:539
StreamSharedStorage< GlobalLoadStreamA, SharedLoadStreamA > stream_a
Definition: gemm_traits.h:535
GemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kRowMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^T.
Definition: gemm_traits.h:369
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Computes derived counts of a Layout Concept-based class.
Definition: shape.h:79
Defines conversion operations among Fragments of different base type.
SharedLoadStreamB::Params shared_stream_b
The params for the B stream from shared memory.
Definition: gemm_traits.h:491
Definition: gemm_traits.h:723
CUTLASS_DEVICE SharedLoadStreamA::Fragment const & fragment_a(int step) const
The fragment A.
Definition: gemm_traits.h:610
static MatrixLayout::Kind const kLayoutB
The layout of B.
Definition: gemm_traits.h:444
static int const kAccumulatorsPerLdsB
Definition: gemm_traits.h:128
static int const kStages
The number of stages in shared memory used to implement double-, triple-, or multi-buffering.
Definition: gemm_traits.h:131
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:620
ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: thread_multiply_add.h:51
GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:212