Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #if !defined(__CUDACC_RTC__)
31 #include <cuda.h>
32 #endif
33 
34 #include <cutlass/coord.h>
35 #include <cutlass/util/platform.h>
36 
37 namespace cutlass {
38 namespace gemm {
39 
41 
/// The CUTLASS GEMM kernel: one thread block computes one output tile.
/// Expects the grid/block configuration produced by Gemm_::launch (grid.x/y
/// cover the M/N tiles, block.x == Gemm_::Traits::GemmConfig::kThreads).
template <typename Gemm_>
__global__ void gemm_kernel(typename Gemm_::Params params) {
  // Statically-sized shared memory backing the main loop and the epilogue.
  __shared__ typename Gemm_::SharedStorage shared_storage;

  // Construct the device-side GEMM object and run the whole computation.
  Gemm_ gemm(params, shared_storage);
  gemm.multiply_add();
}
52 
54 
/// Plain-old-data description of a GEMM problem: sizes, scaling factors and
/// the (device) pointers/strides of the four matrices. Consumed by
/// Gemm::Params::initialize, which forwards it to Traits::Params::initialize.
template <typename Scalar_, typename Index_ = int>
struct GemmDesc {
  /// The dimensions of the GEMM.
  Index_ m, n, k;
  /// The alpha/beta scaling values.
  Scalar_ alpha, beta;
  /// The source matrix A.
  void const* d_a;
  /// The stride (leading dimension) of A.
  Index_ lda;
  /// The source matrix B.
  void const* d_b;
  /// The stride (leading dimension) of B.
  Index_ ldb;
  /// The source matrix C.
  void const* d_c;
  /// The stride (leading dimension) of C.
  Index_ ldc;
  /// The destination matrix D.
  void* d_d;
  /// The stride (leading dimension) of D.
  Index_ ldd;
};
78 
80 
/// The device-side GEMM. The tiling, iterators, math functor and epilogue are
/// all supplied by GemmTraits_; this class only sequences the software
/// pipeline: global->shared loads, shared->register loads, multiply-adds and
/// the final epilogue.
template <typename GemmTraits_>
struct Gemm {
  /// This class.
  typedef Gemm<GemmTraits_> This_;
  /// The traits.
  typedef GemmTraits_ Traits;
  /// The shared storage.
  typedef typename Traits::SharedStorage SharedStorage;

  /// The scalar for A.
  typedef typename Traits::ScalarA ScalarA;
  /// The scalar for B.
  typedef typename Traits::ScalarB ScalarB;
  /// The scalar in the epilogue.
  typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
  /// The scalar for C.
  typedef typename Traits::Epilogue::ScalarC ScalarC;
  /// The scalar for D.
  typedef typename Traits::Epilogue::ScalarD ScalarD;
  /// The index.
  typedef typename Traits::Index Index;

  /// The number of threads per block.
  static int const kThreads = Traits::GemmConfig::kThreads;

  /// The params.
  struct Params : public Traits::Params {
    /// Initialize the parameters from the user-visible GEMM arguments.
    /// Packs them into a GemmDesc and forwards to Traits::Params::initialize;
    /// returns its error code (0 on success).
    CUTLASS_HOST_DEVICE int initialize(Index m,
                                       Index n,
                                       Index k,
                                       ScalarEpilogue alpha,
                                       ScalarA const* d_a,
                                       Index lda,
                                       ScalarB const* d_b,
                                       Index ldb,
                                       ScalarEpilogue beta,
                                       ScalarC const* d_c,
                                       Index ldc,
                                       ScalarD* d_d,
                                       Index ldd) {
      GemmDesc<ScalarEpilogue, Index> desc;
      desc.m = m;
      desc.n = n;
      desc.k = k;
      desc.alpha = alpha;
      desc.beta = beta;
      desc.d_a = reinterpret_cast<void const*>(d_a);
      desc.lda = lda;
      desc.d_b = reinterpret_cast<void const*>(d_b);
      desc.ldb = ldb;
      desc.d_c = reinterpret_cast<void const*>(d_c);
      desc.ldc = ldc;
      desc.d_d = reinterpret_cast<void*>(d_d);
      desc.ldd = ldd;
      return Traits::Params::initialize(desc);
    }
  };

#if !defined(__CUDACC_RTC__)
  /// Launch the kernel with the CUDA runtime API.
  static __host__ cudaError_t launch(Params const& params,
                                     cudaStream_t stream = cudaStreamDefault) {
    // Setup the grid: one block per output tile, ceil-div over M and N.
    dim3 grid;
    grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
    grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;

    // The number of threads.
    dim3 block;
    block.x = kThreads;

    // Launch the kernel.
    void const* params_ = reinterpret_cast<void const*>(&params);

    return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>),
                            grid,
                            block,
                            const_cast<void**>(&params_),
                            0,
                            stream);
  }

  /// Launch the kernel with the CUDA driver API (e.g. for NVRTC-compiled
  /// kernels). The caller supplies the CUfunction; grid/block shapes match
  /// the runtime-API overload above.
  static __host__ cudaError_t launch(CUfunction kernel,
                                     Params const& params,
                                     CUstream stream = CU_STREAM_LEGACY) {
    // Setup the grid: one block per output tile, ceil-div over M and N.
    dim3 grid;
    grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
    grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;

    // The number of threads.
    dim3 block;
    block.x = kThreads;

    // Launch the kernel.
    void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(&params))};

    CUresult result = cuLaunchKernel(
        kernel, grid.x, grid.y, grid.z, block.x, block.y, block.z, 0, stream, params_, 0);

    // NOTE(review): all driver-API failures are collapsed into a single
    // runtime-API error code, losing the specific CUresult.
    if (result != CUDA_SUCCESS) {
      return cudaErrorLaunchFailure;
    }
    return cudaSuccess;
  }

#endif

  /// Ctor.
  CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
      : params(params_), shared_storage(shared_storage_) {}

  /// Do the GEMM: run the double-buffered main loop over K, then the epilogue.
  CUTLASS_DEVICE void multiply_add() {
    // Swizzle the IDs of the block (to enable better cache behavior).
    typename Traits::BlockSwizzle block_swizzle;
    dim3 block = block_swizzle.swizzle();

    // Scale the id to the element coordinates of this block's output tile.
    block.x *= Traits::OutputTile::kW;
    block.y *= Traits::OutputTile::kH;

    // We may want to use shared memory to clear the registers.
    typedef typename Traits::ClearAccumulators ClearAccumulators;

    // The streams to read A/B from global memory to shared memory.
    typename Traits::GlobalLoadStream global_stream(params, shared_storage, block);

    // Create the accumulator clear.
    ClearAccumulators clear(shared_storage.main_loop.clear);

    // The functor doing the math.
    typedef typename Traits::MultiplyAdd MultiplyAdd;

    // By how much we unroll the main loop (K elements consumed per iteration).
    Index const kUnroll = static_cast<Index>(MultiplyAdd::AccumulatorsPerWarp::kD);

    // If we do not have enough steps in the main loop, trigger the residue code.
    if (params.k < kUnroll) {
      global_stream.residue(params.k, true);
    }

    // Fetch the fragments for A and B from global memory.
    global_stream.copy();

    // Copy the elements to shared memory (after transformation if needed).
    global_stream.commit();

    // Make sure the data is in shared memory.
    Traits::shared_store_fence(false);

    // The unrolling steps for the main loop.
    int const kUnrollingSteps =
        MultiplyAdd::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;

    // Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
    static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps");

    // The stream of data from shared memory to fragments.
    typename Traits::SharedLoadStream shared_load_stream(params, shared_storage);

    // Trigger the copy from shared memory for the 1st stream.
    shared_load_stream.copy(0);

    // Allocate the accumulators.
    typename MultiplyAdd::Accumulators accumulators;
    // Clear the accumulators.
    clear.clear(accumulators);

    // Enter the main loop and iterate. outer_k counts down so the final
    // iterations can apply the residue predicates.
    for (Index outer_k = params.k - kUnroll; outer_k > -kUnroll; outer_k -= kUnroll) {
      // If that's the last "load iteration" update the predicates.
      int const is_residue = outer_k <= kUnroll;
      if (is_residue) {
        global_stream.residue(outer_k);
      }

      // Load data for the next iteration of the main loop.
      global_stream.copy();

      // Pipelined math steps: prefetch step+1 from shared memory while
      // computing on step.
      CUTLASS_PRAGMA_UNROLL
      for (int step = 0; step < kUnrollingSteps - 1; ++step) {
        // Trigger the copy from shared memory for the next A/B values.
        shared_load_stream.copy(step + 1);
        // Make sure the values are available for the current iteration to do the multiply-add.
        shared_load_stream.commit(step);

        // Do the math on the fragments of the current iteration.
        MultiplyAdd multiply_add;
        multiply_add.multiply_add(shared_load_stream.fragment_a(step),
                                  shared_load_stream.fragment_b(step),
                                  accumulators,
                                  accumulators);
      }

      // Make sure the data from shared memory has been entirely consumed.
      Traits::shared_load_fence(true);

      // Commit the data in shared memory for A/B.
      global_stream.commit();

      // Make sure the data is in shared memory.
      Traits::shared_store_fence(true);

      // Move to the next stage for the load (if it makes sense).
      shared_load_stream.inc_stage();
      // Trigger the copy from shared memory for the next loop iteration.
      shared_load_stream.copy(0);
      // Make sure the values are available for the current iteration to do the multiply-add.
      shared_load_stream.commit(kUnrollingSteps - 1);

      // Do the math on the fragments of the current iteration (last step).
      MultiplyAdd multiply_add;
      multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
                                shared_load_stream.fragment_b(kUnrollingSteps - 1),
                                accumulators,
                                accumulators);
    }

    // Epilogue: combine the accumulators with C and store D for this tile.
    typedef typename Traits::Epilogue Epilogue;
    Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.m, params.n);
    epilogue.epilogue(cutlass::make_Coord(0, block.y, block.x), accumulators);
  }

  /// The params.
  Params const& params;
  /// The shared storage.
  SharedStorage& shared_storage;
};
315 
317 
318 } // namespace gemm
319 } // namespace cutlass
Definition: gemm.h:56
Definition: convert.h:33
SharedStorage & shared_storage
The shared storage.
Definition: gemm.h:313
Traits::Epilogue::ScalarD ScalarD
The scalar for D.
Definition: gemm.h:99
Scalar_ beta
Definition: gemm.h:60
Index_ k
Definition: gemm.h:58
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm.h:88
The params.
Definition: gemm.h:107
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:241
Params const & params
The params.
Definition: gemm.h:311
Index_ m
The dimensions of the GEMM.
Definition: gemm.h:58
Traits::Epilogue::ScalarC ScalarC
The scalar for C.
Definition: gemm.h:97
Index_ ldb
The stride for B.
Definition: gemm.h:68
C++ features that may be otherwise unimplemented for CUDA device functions.
CUTLASS_DEVICE void multiply_add()
Do the GEMM.
Definition: gemm.h:197
GemmTraits_ Traits
The traits.
Definition: gemm.h:86
Traits::Epilogue::Scalar ScalarEpilogue
The scalar in the epilogue.
Definition: gemm.h:95
Index_ n
Definition: gemm.h:58
Traits::ScalarB ScalarB
The scalar for B.
Definition: gemm.h:93
Definition: clear_accumulators.h:38
void * d_d
The destination matrix D.
Definition: gemm.h:74
Definition: gemm.h:82
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:60
static __host__ cudaError_t launch(CUfunction kernel, Params const &params, CUstream stream=CU_STREAM_LEGACY)
Launch the kernel.
Definition: gemm.h:164
void const * d_a
The source matrix A.
Definition: gemm.h:62
__global__ void gemm_kernel(typename Gemm_::Params params)
Definition: gemm.h:43
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, ScalarEpilogue alpha, ScalarA const *d_a, Index lda, ScalarB const *d_b, Index ldb, ScalarEpilogue beta, ScalarC const *d_c, Index ldc, ScalarD *d_d, Index ldd)
Definition: gemm.h:108
Index_ lda
The stride for A.
Definition: gemm.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Gemm< GemmTraits_ > This_
This class.
Definition: gemm.h:84
Index_ ldc
The stride for C.
Definition: gemm.h:72
CUTLASS_DEVICE Gemm(Params const &params_, SharedStorage &shared_storage_)
Ctor.
Definition: gemm.h:193
#define static_assert(__e, __m)
Definition: platform.h:145
Index_ ldd
The stride for D.
Definition: gemm.h:76
Traits::ScalarA ScalarA
The scalar for A.
Definition: gemm.h:91
CUTLASS_DEVICE void clear(Fragment_ &fragment)
Clear the fragment.
Definition: clear_accumulators.h:47
static int const kThreads
The number of threads.
Definition: gemm.h:104
Scalar_ alpha
The alpha/beta scaling values.
Definition: gemm.h:60
void const * d_c
The source matrix C.
Definition: gemm.h:70
static __host__ cudaError_t launch(Params const &params, cudaStream_t stream=cudaStreamDefault)
Launch the kernel.
Definition: gemm.h:141
Traits::Index Index
The index.
Definition: gemm.h:101
void const * d_b
The source matrix B.
Definition: gemm.h:66