Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
igemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include <cutlass/convert.h>
33 #include <cutlass/gemm/gemm.h>
43 #include <cutlass/reshape_tile.h>
44 
45 namespace cutlass {
46 namespace gemm {
47 
49 
50 template <
52  typename OutputTile_,
54  typename ScalarD_,
56  typename AccumulatorsPerThread_>
58  : public GemmConfig<
60  int8_t,
62  int8_t,
64  ScalarD_,
66  ScalarD_,
68  OutputTile_,
70  ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
72  4,
74  4,
76  16,
78  4,
80  4,
82  16,
84  1,
86  4,
88  1,
90  2> {};
91 
93 
// Partial specialization of IgemmConfig for ScalarD_ = int8_t.
// Compared with the generic GemmConfig instantiation directly above, the third and fourth
// type arguments are int8_t instead of ScalarD_, and the three integer arguments that precede
// the final 2 are 4, 4, 4 instead of 1, 4, 1. All other arguments are the same.
// NOTE(review): the meaning of each positional argument is defined by GemmConfig, which is not
// visible in this listing -- confirm against its declaration before changing any value.
94 template <typename OutputTile_, typename AccumulatorsPerThread_>
95 struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
96  : public GemmConfig<
98  int8_t,
100  int8_t,
102  int8_t,
104  int8_t,
106  OutputTile_,
108  ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
110  4,
112  4,
114  16,
116  4,
118  4,
120  16,
122  4,
124  4,
126  4,
128  2> {};
129 
131 
// Primary template of the IGEMM tile-traits helper for operand A: for any layout it simply
// inherits GemmTileTraitsHelperA<kLayout_, GemmConfig_> without adding or overriding members.
// The MatrixLayout::kColumnMajor case is handled by a dedicated partial specialization.
132 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
133 struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
134 
136 
137 template <typename GemmConfig_>
138 struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
139  : public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
142 
144  static int const kScalarsPerStsA = 16;
145 
149  // The layout.
151  // The pointer is int8_t const.
152  int8_t const,
153  // The tile has size KxM in GEMM's terminology.
155  // The threads are distributed as warps x 32 (the traits may reorganize).
157  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
158  4>
160 
163  // The pointer is int8_t.
164  int8_t,
165  // The tile has size KxM in GEMM's terminology.
166  Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
167  // The threads are distributed as warps x 32 (the traits may reorganize).
168  typename GlobalTileTraits::Threads,
169  // The number of scalars per STS (STS.32 or STS.128, etc).
170  kScalarsPerStsA>
172 };
173 
175 
// Primary template of the IGEMM tile-traits helper for operand B: for any layout it simply
// inherits GemmTileTraitsHelperB<kLayout_, GemmConfig_> without adding or overriding members.
// The MatrixLayout::kRowMajor case is handled by a dedicated partial specialization.
176 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
177 struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
178 
180 
181 template <typename GemmConfig_>
182 struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
183  : public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
186 
188  static int const kScalarsPerStsB = 16;
189 
193  // The layout.
195  // The pointer is int8_t const.
196  int8_t const,
197  // The tile has size KxN in GEMM's terminology.
199  // The threads are distributed as warps x 32 (the traits may reorganize).
201  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
202  4>
204 
207  // The pointer is int8_t.
208  int8_t,
209  // The tile has size KxN in GEMM's terminology.
210  Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
211  // The threads are distributed as warps x 32 (the traits may reorganize).
212  typename GlobalTileTraits::Threads,
213  // The number of scalars per STS (STS.32 or STS.128, etc).
214  kScalarsPerStsB>
216 };
217 
219 
220 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
222 
223 template <typename Iterator_>
224 struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
226 };
227 
228 template <typename Iterator_>
229 struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
231 };
232 
234 
235 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
237 
238 template <typename Iterator_>
239 struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
241 };
242 
243 template <typename Iterator_>
244 struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
246 };
247 
249 
250 template <
252  MatrixLayout::Kind kLayoutA_,
254  MatrixLayout::Kind kLayoutB_,
256  typename OutputTile_,
258  typename ScalarD_,
260  typename EpilogueFunctor_,
262  typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
264  typename Index_ = int>
272 
277  typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
280  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
281  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
288 
292  // The default transformer for B.
293  typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
296  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
297  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
304 
306  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
307  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
315  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
316  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
323 
328 
331 };
332 
334 
335 template <typename ScalarD_>
337  typedef float Scalar;
338 };
339 
// Full specialization for ScalarD_ = int: Scalar is int, whereas the generic definition
// preceding this one declares Scalar as float.
340 template <>
341 struct IgemmEpilogueScalar<int> {
342  typedef int Scalar;
343 };
344 
346 
347 template <
349  MatrixLayout::Kind kLayoutA_,
351  MatrixLayout::Kind kLayoutB_,
353  typename OutputTile_ = Shape<32, 128, 128>,
355  typename ScalarD_ = int,
359  typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
361  typename Index_ = int,
363  typename Helper_ = IgemmTraitsHelper<kLayoutA_,
364  kLayoutB_,
365  OutputTile_,
366  ScalarD_,
367  EpilogueFunctor_,
368  AccumulatorsPerThread_,
369  Index_> >
370 struct IgemmTraits : public GemmTraits<
371  // The config.
372  typename Helper_::GemmConfig,
373  // The stream to load A from global memory to shared memory.
374  typename Helper_::GlobalLoadStreamA,
375  // The stream to load B from global memory to shared memory.
376  typename Helper_::GlobalLoadStreamB,
377  // The stream to load A from shared memory.
378  typename Helper_::SharedLoadStreamA,
379  // The stream to load B from shared memory.
380  typename Helper_::SharedLoadStreamB,
381  // The epilogue.
382  typename Helper_::Epilogue,
383  // The block swizzle to reorganize the grid.
384  IdentityBlockSwizzle,
385  // The index.
386  Index_,
387  // The tool used to clear accumulators.
388  typename Helper_::ClearAccumulators> {};
389 
391 
392 } // namespace gemm
393 } // namespace cutlass
Definition: load_store.h:42
TileLoadIterator< typename GemmTileTraitsHelperB::SharedLoadTileTraits, typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
The iterator to load B from shared memory.
Definition: igemm_traits.h:319
Definition: convert.h:33
IgemmSwizzle< Iterator_ > Transformer
Definition: igemm_traits.h:230
Defines iterators for efficiently loading and storing to global memory.
GemmGlobalIteratorAb< typename GemmTileTraitsHelperA::GlobalTileTraits, Index_ > GlobalLoadIteratorA
The iterator to load A from global memory.
Definition: igemm_traits.h:275
Transposes a fragment of data containing packed 8-bit integer elements.
Copy< typename Iterator_::Fragment > Transformer
Definition: igemm_traits.h:240
Defines structural properties of complete GEMM computation.
GlobalLoadStream< GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: igemm_traits.h:303
Definition: igemm_traits.h:133
TileStoreIterator< typename GemmTileTraitsHelperB::SharedStoreTileTraits, typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
The iterator to store B to shared memory.
Definition: igemm_traits.h:300
IgemmTransformerB< GemmTileTraitsHelperB::kLayout, GlobalLoadIteratorB >::Transformer GlobalTransformerB
Definition: igemm_traits.h:294
Definition: igemm_epilogue.h:290
IgemmContiguousGlobalTileTraits< GemmOperand::kB, MatrixLayout::kRowMajor, int8_t const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, 4 > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^T.
Definition: igemm_traits.h:203
Definition: convert.h:69
GemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ > Base
The base config.
Definition: igemm_traits.h:141
IgemmConfig< OutputTile_, ScalarD_, AccumulatorsPerThread_ > GemmConfig
The IGEMM config.
Definition: igemm_traits.h:267
Definition: gemm_shared_tile.h:38
Definition: tile_iterator.h:62
Implements matrix multiply accumulate operation of 8-bit integer data using DP4A instruction.
Definition: gemm_global_tile.h:159
GemmSharedStoreTileAbTraits< int8_t, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/4, GemmConfig_::OutputTile::kH *4 >, typename GlobalTileTraits::Threads, kScalarsPerStsB > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for B^N.
Definition: igemm_traits.h:215
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
Definition: gemm_global_stream.h:161
Definition: gemm_traits.h:273
GemmGlobalIteratorAb< typename GemmTileTraitsHelperB::GlobalTileTraits, Index_ > GlobalLoadIteratorB
The iterator to load B from global memory.
Definition: igemm_traits.h:291
IgemmContiguousGlobalTileTraits< GemmOperand::kA, MatrixLayout::kColumnMajor, int8_t const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, 4 > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^N.
Definition: igemm_traits.h:159
int Scalar
Definition: igemm_traits.h:342
IgemmSwizzle< Iterator_ > Transformer
Definition: igemm_traits.h:245
Describes layouts of matrices.
Definition: matrix_traits.h:35
IgemmTileTraitsHelperB< kLayoutB_, GemmConfig > GemmTileTraitsHelperB
The GEMM config for B.
Definition: igemm_traits.h:271
Definition: igemm_swizzle.h:38
Definition: igemm_traits.h:177
Definition: igemm_traits.h:265
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:302
GlobalLoadStream< GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: igemm_traits.h:287
SharedLoadStream< SharedLoadIteratorB, Copy< typename SharedLoadIteratorB::Fragment > > SharedLoadStreamB
The stream to load B from shared memory.
Definition: igemm_traits.h:322
Defines iterators for efficiently loading and storing tiles to and from shared memory.
Definition: matrix_traits.h:36
IgemmTileTraitsHelperA< kLayoutA_, GemmConfig > GemmTileTraitsHelperA
The GEMM config for A.
Definition: igemm_traits.h:269
Definition: gemm_shared_stream.h:44
Defines a type for restructuring a tile.
TileLoadIterator< typename GemmTileTraitsHelperA::SharedLoadTileTraits, typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
The iterator to load A from shared memory.
Definition: igemm_traits.h:310
ClearAccumulators< typename MultiplyAdd::ScalarC > ClearAccumulators
The object to clear accumulators.
Definition: igemm_traits.h:327
Definition: gemm_traits.h:79
Definition: gemm_traits.h:137
Definition: matrix_traits.h:43
Definition: igemm_traits.h:57
Definition: igemm_traits.h:221
Definition: igemm_global_tile.h:50
float Scalar
Definition: igemm_traits.h:337
Definition: gemm_traits.h:428
Copy< typename Iterator_::Fragment > Transformer
Definition: igemm_traits.h:225
Definition: igemm_traits.h:370
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
GemmSharedStoreTileAbTraits< int8_t, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/4, GemmConfig_::OutputTile::kW *4 >, typename GlobalTileTraits::Threads, kScalarsPerStsA > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for A^N.
Definition: igemm_traits.h:171
Template performing matrix multiply-add operation within a thread.
Definition: thread_multiply_add.h:43
Definition: matrix_traits.h:36
IgemmEpilogue< IgemmEpilogueTraits< GemmConfig, EpilogueFunctor_ > > Epilogue
The epilogue.
Definition: igemm_traits.h:330
IgemmTransformerA< GemmTileTraitsHelperA::kLayout, GlobalLoadIteratorA >::Transformer GlobalTransformerA
The default transformer for A.
Definition: igemm_traits.h:278
Kind
Definition: matrix_traits.h:36
Definition: igemm_traits.h:236
TileStoreIterator< typename GemmTileTraitsHelperA::SharedStoreTileTraits, typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
The iterator to store A to shared memory.
Definition: igemm_traits.h:284
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:40
Definition: matrix_traits.h:43
Implements a software-pipelined efficient GEMM.
ReshapeThreads< Tile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:87
Defines structural properties of the GEMM epilogue.
Definition: igemm_traits.h:336
Defines the epilogue phase of the GEMM computation for IGEMM, supporting integer and floating-point o...
Defines conversion operations among Fragments of different base type.
GemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ > Base
The base config.
Definition: igemm_traits.h:185
SharedLoadStream< SharedLoadIteratorA, Copy< typename SharedLoadIteratorA::Fragment > > SharedLoadStreamA
The stream to load A from shared memory.
Definition: igemm_traits.h:313
Implements tile iterators to partition the thread block tile into 2D subtiles and efficiently load ea...
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:620
GemmConfig::MultiplyAdd MultiplyAdd
The multiply-add functor.
Definition: igemm_traits.h:325