Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
hgemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/convert.h>
31 #include <cutlass/reshape_tile.h>
32 
33 #include <cutlass/gemm/gemm.h>
42 
43 namespace cutlass {
44 namespace gemm {
45 
47 
48 template <
50  typename OutputTile_,
52  typename AccumulatorsPerThread_,
54  int kScalarsPerLdgA_ = 2,
56  int kScalarsPerLdgB_ = 2>
58  : public GemmConfig<
60  half,
62  half,
64  half,
66  half,
68  OutputTile_,
70  ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, half, half, half>,
72  kScalarsPerLdgA_,
74  kScalarsPerLdgA_,
76  8,
78  kScalarsPerLdgB_,
80  kScalarsPerLdgB_,
82  8,
84  2,
86  8,
88  2,
90  2> {};
91 
93 
94 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
96 
97 template <typename Iterator_>
98 struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
100 };
101 
102 template <typename Iterator_>
103 struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
105 };
106 
108 
109 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
111 
112 template <typename Iterator_>
113 struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
115 };
116 
117 template <typename Iterator_>
118 struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
120 };
121 
123 
124 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
125 struct HgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
126 
128 
129 template <typename GemmConfig_>
130 struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
131  : public GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
134 
138  // The layout.
140  // The pointer.
141  half const,
142  // The tile has size MxK in GEMM's terminology.
144  // The threads are distributed as (threads / K) x K (the traits may reorganize).
145  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
146  // The number of scalars per LDG (LDG.32 or LDG.128, etc)
147  GemmConfig_::kScalarsPerLdgA>
149 
152  // The pointer.
153  half,
154  // The tile has size KxM in GEMM's terminology.
155  Shape<GemmConfig_::kStages,
156  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
157  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
158  // The threads are distributed as (threads / K) x K (the traits may reorganize).
159  typename GlobalTileTraits::Threads,
160  // The number of scalars per STS (STS.32 or STS.128, etc).
161  2,
162  // The skew to avoid bank conflicts added in the tile W dimension.
163  128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2>
165 
168  // The pointer.
169  half const,
170  // The output tile size.
171  typename GemmConfig_::OutputTile,
172  // The number of warps.
173  typename GemmConfig_::Warps,
174  // The number of threads per warp.
175  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
176  // The shape of the FMA instruction.
177  typename GemmConfig_::InstructionShape,
178  // The number of stages.
179  GemmConfig_::kStages,
180  // The number of scalars per LDS.
181  8,
182  // The skew.
183  SharedStoreTileTraits::kSkew>
185 };
186 
188 
189 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
190 struct HgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
191 
193 
194 template <typename GemmConfig_>
195 struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
196  : public GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
199 
203  // The layout.
205  // The pointer.
206  half const,
207  // The tile has size KxN in GEMM's terminology.
209  // The threads are distributed as (threads / K) x K (the traits may reorganize).
210  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
211  // The number of scalars per LDG (LDG.32 or LDG.128, etc)
212  GemmConfig_::kScalarsPerLdgB>
214 
217  // The pointer.
218  half,
219  // The tile has size KxN in GEMM's terminology.
220  Shape<GemmConfig_::kStages,
221  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
222  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
223  // The threads are distributed as (threads / K) x K (the traits may reorganize).
224  typename GlobalTileTraits::Threads,
225  // The number of scalars per STS (STS.32 or STS.128, etc).
226  2,
227  // The skew to avoid bank conflicts added in the tile W dimension.
228  128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2>
230 
233  // The pointer.
234  half const,
235  // The output tile size.
236  typename GemmConfig_::OutputTile,
237  // The number of warps.
238  typename GemmConfig_::Warps,
239  // The number of threads per warp.
240  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
241  // The shape of the FMA instruction.
242  typename GemmConfig_::InstructionShape,
243  // The number of stages.
244  GemmConfig_::kStages,
245  // The number of scalars per LDS.
246  8,
247  // The skew.
248  SharedStoreTileTraits::kSkew>
250 };
251 
253 
254 template <
256  MatrixLayout::Kind kLayoutA_,
258  MatrixLayout::Kind kLayoutB_,
260  typename OutputTile_,
262  typename EpilogueFunctor_,
264  typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
266  int kScalarsPerLdgA_ = 2,
268  int kScalarsPerLdgB_ = 2,
270  typename Index_ = int>
279 
284  typedef typename HgemmTransformerA<GemmTileTraitsHelperA::kLayout,
287  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
288  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
295 
299  // The default transformer for B.
300  typedef typename HgemmTransformerB<GemmTileTraitsHelperB::kLayout,
303  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
304  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
311 
313  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
314  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
321  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
322  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
328 
333 
338 };
339 
341 
342 template <
344  MatrixLayout::Kind kLayoutA_,
346  MatrixLayout::Kind kLayoutB_,
348  typename OutputTile_ = Shape<8, 128, 128>,
350  typename EpilogueFunctor_ = LinearScaling<half>,
352  typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
354  int kScalarsPerLdgA_ = 2,
356  int kScalarsPerLdgB_ = 2,
358  typename Index_ = int,
360  typename Helper_ = HgemmTraitsHelper<kLayoutA_,
361  kLayoutB_,
362  OutputTile_,
363  EpilogueFunctor_,
364  AccumulatorsPerThread_,
365  kScalarsPerLdgA_,
366  kScalarsPerLdgB_,
367  Index_> >
368 struct HgemmTraits : public GemmTraits<
369  // The config.
370  typename Helper_::GemmConfig,
371  // The stream to load A from global memory to shared memory.
372  typename Helper_::GlobalLoadStreamA,
373  // The stream to load B from global memory to shared memory.
374  typename Helper_::GlobalLoadStreamB,
375  // The stream to load A from shared memory.
376  typename Helper_::SharedLoadStreamA,
377  // The stream to load B from shared memory.
378  typename Helper_::SharedLoadStreamB,
379  // The epilogue.
380  typename Helper_::Epilogue,
381  // The block swizzle to reorganize the grid.
382  IdentityBlockSwizzle,
383  // The index.
384  Index_,
385  // The tool used to clear accumulators.
386  typename Helper_::ClearAccumulators> {};
387 
389 
390 } // namespace gemm
391 } // namespace cutlass
GemmGlobalIteratorAb< typename GemmTileTraitsHelperA::GlobalTileTraits, Index_ > GlobalLoadIteratorA
The iterator to load A from global memory.
Definition: hgemm_traits.h:282
Definition: load_store.h:42
HgemmSwizzle< Iterator_ > Transformer
Definition: hgemm_traits.h:119
Definition: convert.h:33
Definition: gemm_shared_tile.h:129
Definition: gemm_epilogue.h:53
Defines iterators for efficiently loading and storing to global memory.
GemmGlobalIteratorAb< typename GemmTileTraitsHelperB::GlobalTileTraits, Index_ > GlobalLoadIteratorB
The iterator to load B from global memory.
Definition: hgemm_traits.h:298
ClearAccumulators< typename MultiplyAdd::ScalarC > ClearAccumulators
The object to clear accumulators.
Definition: hgemm_traits.h:332
Defines structural properties of complete GEMM computation.
TileStoreIterator< typename GemmTileTraitsHelperA::SharedStoreTileTraits, typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
The iterator to store A to shared memory.
Definition: hgemm_traits.h:291
GlobalLoadStream< GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: hgemm_traits.h:294
HgemmCrosswiseGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, half const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^N.
Definition: hgemm_traits.h:213
Definition: hgemm_traits.h:95
GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ > Base
The base config.
Definition: hgemm_traits.h:198
SharedLoadStream< SharedLoadIteratorA > SharedLoadStreamA
The stream to load A from shared memory.
Definition: hgemm_traits.h:319
Convert< typename Iterator_::Fragment, typename Iterator_::Fragment > Transformer
Definition: hgemm_traits.h:99
Definition: hgemm_traits.h:368
HgemmSwizzle< Iterator_ > Transformer
Definition: hgemm_traits.h:104
Definition: tile_iterator.h:62
Definition: gemm_shared_tile.h:198
TileLoadIterator< typename GemmTileTraitsHelperB::SharedLoadTileTraits, typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
The iterator to load B from shared memory.
Definition: hgemm_traits.h:325
Definition: gemm_global_tile.h:159
GemmEpilogue< GemmEpilogueTraits > Epilogue
The epilogue.
Definition: hgemm_traits.h:337
HgemmTransformerA< GemmTileTraitsHelperA::kLayout, GlobalLoadIteratorA >::Transformer GlobalTransformerA
The default transformer for A.
Definition: hgemm_traits.h:285
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the computed matrix product.
Definition: gemm_global_stream.h:161
Definition: gemm_traits.h:273
Definition: hgemm_traits.h:125
Describes layouts of matrices.
Definition: matrix_traits.h:35
SharedLoadStream< SharedLoadIteratorB > SharedLoadStreamB
The stream to load B from shared memory.
Definition: hgemm_traits.h:327
Definition: hgemm_traits.h:110
GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ > Base
The base config.
Definition: hgemm_traits.h:133
TileLoadIterator< typename GemmTileTraitsHelperA::SharedLoadTileTraits, typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
The iterator to load A from shared memory.
Definition: hgemm_traits.h:317
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:302
SimplifiedGemmEpilogueTraits< GemmConfig, EpilogueFunctor_, Index_ > GemmEpilogueTraits
The traits class for the epilogue.
Definition: hgemm_traits.h:335
Defines iterators for efficiently loading and storing tiles to and from shared memory.
Definition: matrix_traits.h:36
Definition: gemm_shared_stream.h:44
Defines a type for restructuring a tile.
Specialization implementing multiply-add operation on half-precision floating point fragments...
Definition: gemm_traits.h:79
Transposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in shared memory for...
Definition: gemm_traits.h:137
GemmSharedLoadTileBTraits< half const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, 8, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for B^N.
Definition: hgemm_traits.h:249
Definition: matrix_traits.h:43
HgemmConfig< OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_ > GemmConfig
The HGEMM config.
Definition: hgemm_traits.h:274
Definition: hgemm_traits.h:190
GlobalLoadStream< GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: hgemm_traits.h:310
GemmConfig::MultiplyAdd MultiplyAdd
The functor to do the multiply-add in the main loop.
Definition: hgemm_traits.h:330
HgemmTileTraitsHelperB< kLayoutB_, GemmConfig > GemmTileTraitsHelperB
The GEMM config for B.
Definition: hgemm_traits.h:278
Definition: gemm_traits.h:428
Definition: hgemm_global_tile.h:48
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: gemm_epilogue_traits.h:300
GemmSharedLoadTileATraits< half const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, 8, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for A^T.
Definition: hgemm_traits.h:184
HgemmTileTraitsHelperA< kLayoutA_, GemmConfig > GemmTileTraitsHelperA
The GEMM config for A.
Definition: hgemm_traits.h:276
Template performing matrix multiply-add operation within a thread.
Definition: thread_multiply_add.h:43
Definition: matrix_traits.h:36
Kind
Definition: matrix_traits.h:36
HgemmTransformerB< GemmTileTraitsHelperB::kLayout, GlobalLoadIteratorB >::Transformer GlobalTransformerB
Definition: hgemm_traits.h:301
Definition: hgemm_traits.h:271
HgemmCrosswiseGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, half const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^T.
Definition: hgemm_traits.h:148
Tile traits used to construct global tile iterator for HGEMM. This is intended to partition the threa...
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:40
Definition: convert.h:38
Definition: matrix_traits.h:43
Implements a software-pipelined efficient GEMM.
ReshapeThreads< Tile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:87
Defines structural properties of the GEMM epilogue.
Definition: hgemm_swizzle.h:40
Defines conversion operations among Fragments of different base type.
Convert< typename Iterator_::Fragment, typename Iterator_::Fragment > Transformer
Definition: hgemm_traits.h:114
Definition: hgemm_traits.h:57
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:620
TileStoreIterator< typename GemmTileTraitsHelperB::SharedStoreTileTraits, typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
The iterator to store B to shared memory.
Definition: hgemm_traits.h:307