Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_epilogue_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/wmma_matrix.h>
31 #ifdef CUTLASS_USE_WMMA_API
32 
33 #include <cutlass/convert.h>
34 #include <cutlass/coord.h>
40 #include <cutlass/reshape_tile.h>
41 #include <cutlass/tile_iterator.h>
42 
43 namespace cutlass {
44 namespace gemm {
45 
47 
/// Bundles the type definitions (tile traits, iterators and fragment transformers)
/// from which the WMMA GEMM epilogue is assembled.
///
/// Template parameters:
///   GemmConfig_      - GEMM configuration. Must provide OutputTile, AccumulatorsPerWarp,
///                      InstructionShape, Warps, kWarpSize, kScalarsPerLdsD,
///                      kScalarsPerLdgC, kScalarsPerStgD, ScalarC and ScalarD.
///   EpilogueFunctor_ - epilogue math functor. Must provide a nested Scalar type.
///   Index_           - index type forwarded to the store iterator and to the C/D
///                      iterators (default: int).
///
/// NOTE(review): this is a rendered listing whose line numbering has gaps
/// (e.g. 67->69, 79->81, 87->90, 91->93, 112->115). Because of these gaps, some
/// template arguments are not visible here. For example, the one announced by
/// "The output layout." is missing, and the TileStoreIterator argument list ends
/// in "WmmaMatrix," with no closing '>'. Check the real header before relying on
/// this text.
48 template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
49 struct WmmaGemmEpilogueTraitsHelper {
    /// Scalar type of the epilogue functor.
51  typedef typename EpilogueFunctor_::Scalar Scalar;
    /// Output tile shape, taken from the GEMM configuration.
53  typedef typename GemmConfig_::OutputTile OutputTile;
54 
    /// Number of WMMA instructions along H per warp:
    /// AccumulatorsPerWarp::kH / InstructionShape::kH.
    /// NOTE(review): integer division truncates. Presumably AccumulatorsPerWarp::kH
    /// is required to be a multiple of InstructionShape::kH; confirm.
56  static int const kWmmasPerH =
57  GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
    /// Epilogue iteration space 1 x 1 x kWmmasPerH, presumably one epilogue pass
    /// per WMMA along H.
59  typedef Shape<1, 1, kWmmasPerH> Iterations;
60  // The iteration strides in the H/W dimension.
61  typedef Shape<0, 0, 0> Delta;
    /// The epilogue functor, re-exported under a shorter name.
63  typedef EpilogueFunctor_ Functor;
64 
    /// Store-side tile traits for D. They are parameterized by the functor's Scalar,
    /// the output tile, the warp arrangement and the WMMA instruction shape.
66  typedef WmmaGemmSharedStoreTileDTraits<
67  // The output layout (the argument line itself is missing from this listing).
69  // The pointer element type is Functor::Scalar.
70  typename Functor::Scalar,
71  // The output tile size.
72  typename GemmConfig_::OutputTile,
73  // The number of warps.
74  typename GemmConfig_::Warps,
75  // The shape of the instruction.
76  typename GemmConfig_::InstructionShape>
77  SharedStoreTileTraits;
78 
    /// WMMA matrix type for GEMM operand C, with Scalar elements and the configured
    /// instruction shape.
    /// NOTE(review): this member typedef reuses the name of the class template it
    /// instantiates. Inside the class, after this point, `WmmaMatrix` (e.g. on line
    /// 91 below) refers to the typedef. That is a "changes meaning" situation that
    /// some compilers (e.g. GCC) may diagnose, so consider renaming the typedef.
79  typedef WmmaMatrix<GemmOperand::kC,
81  Scalar,
82  typename GemmConfig_::InstructionShape>
83  WmmaMatrix;
84 
    /// Store iterator for D built on SharedStoreTileTraits. It takes Index_ and the
    /// WmmaMatrix typedef as template arguments; the remaining arguments (listing
    /// lines 88-89 and 92) are missing here.
86  typedef TileStoreIterator<SharedStoreTileTraits,
87  typename SharedStoreTileTraits::Scalar,
90  Index_,
91  WmmaMatrix,
93  SharedStoreIteratorD;
94 
    /// Transformer applied to the D store iterator's fragment. Copy is declared
    /// elsewhere; presumably it is an identity (no conversion) transform. Confirm.
96  typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
97 
    /// Load-side tile traits for reading D back. They cover the tile type produced by
    /// SharedStoreIteratorD, with threads arranged as 1 x (number of warps) x kWarpSize
    /// and kScalarsPerLdsD scalars per LDS.
99  typedef WmmaGemmSharedLoadTileDTraits<
100  // The pointer element type is Functor::Scalar.
101  typename Functor::Scalar,
102  // The tile size.
103  typename SharedStoreIteratorD::Tile,
104  // The thread arrangement: 1 x (number of warps) x kWarpSize.
105  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
106  // The number of scalars per LDS.
107  GemmConfig_::kScalarsPerLdsD>
108  SharedLoadTileTraits;
109 
    /// Load iterator for D built on SharedLoadTileTraits. The remaining template
    /// arguments (listing lines 113-114) are missing here.
111  typedef TileLoadIterator<SharedLoadTileTraits,
112  typename SharedLoadTileTraits::Scalar,
115  SharedLoadIteratorD;
116 
    /// Tile traits for reading operand C. The traits specify:
    ///   - read-only (ScalarC const) elements;
    ///   - a tile of OutputTile::kH / (iteration count) by OutputTile::kW;
    ///   - warps x kWarpSize threads;
    ///   - kScalarsPerLdgC scalars per LDG.
118  typedef WmmaGemmGlobalIteratorCdTraits<
119  // The pointer element type is GemmConfig_::ScalarC, read-only (const).
120  typename GemmConfig_::ScalarC const,
121  // The tile is (OutputTile::kH / iteration count) x OutputTile::kW, i.e. (N / Iterations)xM in GEMM's terminology.
122  Shape<1,
123  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
124  GemmConfig_::OutputTile::kW>,
125  // The threads are distributed as warps x kWarpSize (the traits may reorganize).
126  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
127  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
128  GemmConfig_::kScalarsPerLdgC>
129  GlobalLoadTileTraits;
130 
    /// Iterator that loads C according to GlobalLoadTileTraits.
132  typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
    /// Copy transformer applied to C fragments after they are loaded.
134  typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
135 
    /// Tile traits for writing D. They use ScalarD elements and the same tile slice
    /// and thread arrangement as C, with kScalarsPerStgD scalars per STG.
137  typedef WmmaGemmGlobalIteratorCdTraits<
138  // The pointer element type is GemmConfig_::ScalarD (writable).
139  typename GemmConfig_::ScalarD,
140  // The tile is (OutputTile::kH / iteration count) x OutputTile::kW, i.e. (N / Iterations)xM in GEMM's terminology.
141  Shape<1,
142  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
143  GemmConfig_::OutputTile::kW>,
144  // The threads are distributed as warps x kWarpSize (the traits may reorganize).
145  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
146  // The number of scalars per STG (STG.32 or STG.128, etc).
147  GemmConfig_::kScalarsPerStgD>
148  GlobalStoreTileTraits;
149 
    /// Iterator that stores D according to GlobalStoreTileTraits.
151  typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
    /// Copy transformer applied to D fragments before they are stored.
153  typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
154 };
155 
157 
158 } // namespace gemm
159 } // namespace cutlass
160 
161 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:42
Definition: convert.h:33
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Implements the BLAS linear scaling function alpha*AB + beta*C.
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: tile_iterator.h:62
Definition: matrix_traits.h:43
Defines a type for restructuring a tile.
Definition: tile_iterator.h:67
Defines tile iterator traits for loading thread block-level tile from global memory.
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
Definition: matrix_traits.h:36
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Defines conversion operations among Fragments of different base type.
Defines iterator traits for efficiently loading and storing fragments to and from shared memory...