Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_epilogue_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/wmma_matrix.h"
31 #ifdef CUTLASS_USE_WMMA_API
32 
33 #include "cutlass/convert.h"
34 #include "cutlass/coord.h"
40 #include "cutlass/reshape_tile.h"
41 #include "cutlass/tile_iterator.h"
42 
43 namespace cutlass {
44 namespace gemm {
45 
47 
48 template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
49 struct WmmaGemmEpilogueTraitsHelper {
51  typedef typename EpilogueFunctor_::Scalar Scalar;
53  typedef typename GemmConfig_::OutputTile OutputTile;
54 
56  static int const kWmmasPerH =
57  GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
59  typedef Shape<1, 1, kWmmasPerH> Iterations;
60  // The iteration strides in the H/W dimension.
61  typedef Shape<0, 0, 0> Delta;
63  typedef EpilogueFunctor_ Functor;
64 
66  typedef WmmaGemmSharedStoreTileDTraits<
67  // The output layout.
69  // The pointer is float.
70  typename Functor::Scalar,
71  // The output tile size.
72  typename GemmConfig_::OutputTile,
73  // The number of warps.
74  typename GemmConfig_::Warps,
75  // The shape of the instruction.
76  typename GemmConfig_::InstructionShape>
77  SharedStoreTileTraits;
78 
79  typedef WmmaMatrix<GemmOperand::kC,
81  Scalar,
82  typename GemmConfig_::InstructionShape>
83  WmmaMatrix;
84 
86  typedef TileStoreIterator<SharedStoreTileTraits,
87  typename SharedStoreTileTraits::Scalar,
90  Index_,
91  WmmaMatrix,
93  SharedStoreIteratorD;
94 
96  typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
97 
99  typedef WmmaGemmSharedLoadTileDTraits<
100  // The pointer.
101  typename Functor::Scalar,
102  // The tile size.
103  typename SharedStoreIteratorD::Tile,
104  // The number of threads.
105  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
106  // The number of scalars per LDS.
107  GemmConfig_::kScalarsPerLdsD>
108  SharedLoadTileTraits;
109 
111  typedef TileLoadIterator<SharedLoadTileTraits,
112  typename SharedLoadTileTraits::Scalar,
115  SharedLoadIteratorD;
116 
118  typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
119 
121  typedef WmmaGemmGlobalIteratorCdTraits<
122  // The pointer is float const.
123  typename GemmConfig_::ScalarC const,
124  // The tile has size (N / Iterations)xM in GEMM's terminology.
125  Shape<1,
126  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
127  GemmConfig_::OutputTile::kW>,
128  // The threads are distributed as warps x 32 (the traits may reorganize).
129  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
130  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
131  GemmConfig_::kScalarsPerLdgC>
132  GlobalLoadTileTraits;
133 
135  typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
137  typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
138 
140  typedef WmmaGemmGlobalIteratorCdTraits<
141  // The pointer is float.
142  typename GemmConfig_::ScalarD,
143  // The tile has size (N / Iterations)xM in GEMM's terminology.
144  Shape<1,
145  GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
146  GemmConfig_::OutputTile::kW>,
147  // The threads are distributed as warps x 32 (the traits may reorganize).
148  Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
149  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
150  GemmConfig_::kScalarsPerStgD>
151  GlobalStoreTileTraits;
152 
154  typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
156  typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
157 };
158 
160 
161 } // namespace gemm
162 } // namespace cutlass
163 
164 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Definition: convert.h:33
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Implements the BLAS linear scaling function alpha*AB + beta*C.
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: tile_iterator.h:65
Definition: matrix_traits.h:357
Defines a type for restructuring a tile.
Defines tile iterator traits for loading thread block-level tile from global memory.
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
Definition: matrix_traits.h:159
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Defines conversion operations among Fragments of different base type.
Defines iterator traits for efficiently loading and storing fragments to and from shared memory...