Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_shared_tile.h
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 * * Redistributions of source code must retain the above copyright notice, this list of
 * conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright notice, this list of
 * conditions and the following disclaimer in the documentation and/or other materials
 * provided with the distribution.
 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 * to endorse or promote products derived from this software without specific prior written
 * permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <cutlass/wmma_matrix.h>
#ifdef CUTLASS_USE_WMMA_API

#include <cutlass/reshape_tile.h>

namespace cutlass {
namespace gemm {

template <class>
struct Debug {};

template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Warps_,
          int kWarpStride_,
          typename Iterations_,
          typename Delta_,
          typename WmmaShape_>
struct WmmaGemmSharedLoadTileATraits {
  /// The identity of the operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kA;
  /// The layout of the operand in memory.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar type.
  typedef Scalar_ Scalar;
  /// The pointer type.
  typedef Scalar const* Pointer;
  /// The number of scalars per access.
  static int const kAccessSize = 1;
  /// The tile shape in shared memory.
  typedef Tile_ Tile;
  /// The arrangement of warps in the threadblock.
  typedef Warps_ Warps;
  /// The stride between the tiles of adjacent warps.
  static int const kWarpStride = kWarpStride_;
  /// The number of iterations.
  typedef Iterations_ Iterations;
  /// The strides between accesses.
  typedef Delta_ Delta;
  typedef Delta_ ImmediateOffsetStrides;
  /// The shape of a single WMMA operation.
  typedef WmmaShape_ WmmaShape;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// Computes the offset of the warp in the shared memory tile.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The offset: for the A operand, warps are distributed along the W dimension.
      int const offset = warp % Warps::kW * kWarpStride;
      return make_Coord(0, 0, offset, 0);
    }
  };
};

template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Warps_,
          int kWarpStride_,
          typename Iterations_,
          typename Delta_,
          typename WmmaShape_>
struct WmmaGemmSharedLoadTileBTraits {
  /// The identity of the operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kB;
  /// The layout of the operand in memory.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar type.
  typedef Scalar_ Scalar;
  /// The pointer type.
  typedef Scalar const* Pointer;
  /// The number of scalars per access.
  static int const kAccessSize = 1;
  /// The tile shape in shared memory.
  typedef Tile_ Tile;
  /// The arrangement of warps in the threadblock.
  typedef Warps_ Warps;
  /// The stride between the tiles of adjacent warps.
  static int const kWarpStride = kWarpStride_;
  /// The number of iterations.
  typedef Iterations_ Iterations;
  /// The strides between accesses.
  typedef Delta_ Delta;
  typedef Delta_ ImmediateOffsetStrides;
  /// The shape of a single WMMA operation.
  typedef WmmaShape_ WmmaShape;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// Computes the offset of the warp in the shared memory tile.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The offset: for the B operand, warps are distributed along the H dimension.
      int const offset = warp / Warps::kW * kWarpStride;
      return make_Coord(0, 0, offset, 0);
    }
  };
};

template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename OutputTile_,
          typename Warps_,
          typename WmmaShape_,
          int kSkew_ = 0>
struct WmmaGemmSharedStoreTileDTraits {
  /// The identity of the operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kC;
  /// The layout of the operand in memory.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar type.
  typedef Scalar_ Scalar;
  /// The access size.
  static int const kAccessSize = 1;
  /// The pointer type.
  typedef Scalar* Pointer;
  /// The arrangement of warps in the threadblock.
  typedef Warps_ Warps;
  /// The shape of a single WMMA operation.
  typedef WmmaShape_ WmmaShape;
  /// The skew added to the tile width (to reduce shared memory bank conflicts).
  static int const kSkew = kSkew_;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
  /// The tile in shared memory, including the skew.
  typedef Shape<1, Warps_::kH * WmmaShape_::kH, OutputTile_::kW + kSkew_> Tile;
  /// The number of iterations per warp.
  typedef Shape<1, 1, OutputTile_::kW / Warps::kW / WmmaShape_::kW> Iterations;
  /// The strides between accesses.
  typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> Delta;
  typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;

  /// Computes the offset of the warp in the shared memory tile.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The starting row of the warp's tile.
      int const h = warp / Warps::kW * WmmaShape::kH;
      // The starting column of the warp's tile.
      int const w = warp % Warps::kW * WmmaShape::kW;
      // The offset.
      int const offset = h * Tile::kW + w;
      return make_Coord(0, 0, offset, 0);
    }
  };
};

template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerLds_>
struct WmmaGemmSharedLoadTileDTraits {
  /// The scalar type.
  typedef Scalar_ Scalar;
  /// The pointer type.
  typedef Scalar const* Pointer;
  /// The number of scalars per shared memory load.
  static int const kAccessSize = kScalarsPerLds_;
  /// The tile, reshaped so the innermost dimension holds kScalarsPerLds_ scalars.
  typedef typename ReshapeTile<Tile_, kScalarsPerLds_>::Tile Tile;
  /// The threads, reshaped to match the tile.
  typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
  /// The strides of the thread arrangement.
  typedef Shape<1, Tile::kW * Tile::kC, Tile::kC> ThreadsStrides;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// The strides between accesses.
  typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
  typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_>
      ImmediateOffsetStrides;
  /// The number of iterations.
  typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
      Iterations;

  /// Computes the offset of the thread in the shared memory tile.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      // The offset, computed from the thread strides. The exact expression is an assumed
      // reconstruction: ComputeThreadOffsetFromStrides (shape.h) is used here.
      int const offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
      return make_Coord(0, 0, offset, 0);
    }
  };
};

} // namespace gemm
} // namespace cutlass

#endif // defined CUTLASS_USE_WMMA_API
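
The ThreadOffset functors above reduce to simple integer arithmetic on the warp index: the A operand spreads warps along the W dimension of the warp grid (warp % Warps::kW), while the B operand spreads them along H (warp / Warps::kW), each scaled by the warp stride. The host-only sketch below mirrors that arithmetic so the mapping is easy to inspect; the warp-grid width, warp stride, and warp count are illustrative assumptions, not constants taken from CUTLASS.

#include <cstdio>

int main() {
  int const kWarpsW     = 4;   // warps along the W dimension of the block tile (assumed)
  int const kWarpStride = 64;  // stride between adjacent warps' tiles, in scalars (assumed)
  int const kNumWarps   = 8;   // total warps in the threadblock (assumed)

  for (int warp = 0; warp < kNumWarps; ++warp) {
    // A operand: warps distributed along W, as in WmmaGemmSharedLoadTileATraits::ThreadOffset.
    int const offset_a = warp % kWarpsW * kWarpStride;
    // B operand: warps distributed along H, as in WmmaGemmSharedLoadTileBTraits::ThreadOffset.
    int const offset_b = warp / kWarpsW * kWarpStride;
    std::printf("warp %d: A offset = %3d, B offset = %3d\n", warp, offset_a, offset_b);
  }
  return 0;
}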