Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_epilogue.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include <cutlass/convert.h>
33 #include <cutlass/coord.h>
34 #include <cutlass/fragment.h>
35 
36 namespace cutlass {
37 namespace gemm {
38 
40 
41 template <typename T>
42 CUTLASS_DEVICE bool is_zero(T x) {
43  return x == T(0);
44 }
45 
46 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
47 CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast<int16_t&>(x) == int16_t(0); }
48 #endif
49 
51 
52 template <typename GemmEpilogueTraits_>
53 struct GemmEpilogue {
55  typedef GemmEpilogueTraits_ Traits;
57  typedef typename Traits::Params Params;
59  typedef typename Traits::SharedStorage SharedStorage;
60 
62  typedef typename Traits::OutputTile OutputTile;
64  typedef typename Traits::Iterations Iterations;
66  typedef typename Traits::Accumulators Accumulators;
68  typedef typename Traits::Scalar Scalar;
70  typedef typename Traits::Functor Functor;
71 
73  static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
74 
76  typedef typename Traits::GlobalLoadIteratorC GlobalLoadIteratorC;
78  typedef typename Traits::GlobalTransformerC GlobalTransformerC;
80  typedef typename Traits::GlobalTransformerD GlobalTransformerD;
82  typedef typename Traits::GlobalStoreIteratorD GlobalStoreIteratorD;
84  typedef typename Traits::SharedStoreIteratorD SharedStoreIteratorD;
86  typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD;
88  typedef typename Traits::SharedLoadIteratorD SharedLoadIteratorD;
91 
93  typedef typename Traits::Index Index;
94 
96  typedef typename GlobalLoadIteratorC::Scalar ScalarC;
98  typedef typename GlobalStoreIteratorD::Scalar ScalarD;
99 
101  CUTLASS_DEVICE GemmEpilogue(Params const& params_,
102  SharedStorage& shared_storage_,
103  Index m_,
104  Index n_)
105  : params(params_), shared_storage(shared_storage_), m(m_), n(n_) {}
106 
108  CUTLASS_DEVICE void epilogue(Coord<3> const& block, Accumulators& accumulators) {
109  if (is_zero(params.functor.beta)) {
110  epilogue_with_or_without_beta<true>(block, accumulators);
111  } else {
112  epilogue_with_or_without_beta<false>(block, accumulators);
113  }
114  }
115 
116  template <bool kBetaIsZero_>
117  CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block,
118  Accumulators& accumulators) {
119 
120  Coord<3> const bounds = cutlass::make_Coord(0, n, m);
121 
122  // The functor.
123  Functor functor(params.functor);
124  // The C fragment.
125  typename GlobalLoadIteratorC::Fragment fragment_c;
126  // The transformed C fragment.
127  typename GlobalTransformerC::OutputFragment transformed_c;
128 
130  for (int h = 0; h < Iterations::kH; ++h) {
131  // Compute pointer and predicate offsets for C and D global iterators.
132  int const pointer_offset =
133  ((params.iterator_d.inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
134  params.iterator_d.inc_advance) *
135  Iterations::kW +
136  params.stride_h) *
137  h;
138  int const predicate_offset =
139  ((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
140  params.iterator_d.predicate_inc_advance) *
141  Iterations::kW +
142  Traits::Delta::kH) *
143  h;
144 
145  // The iterator to load the elements of the C matrix.
146  GlobalLoadIteratorC global_load_iterator(
147  params.iterator_c, bounds, block, pointer_offset, predicate_offset);
148  // The transformer for C.
149  GlobalTransformerC transformer_c;
150  // The transformer for D.
151  GlobalTransformerD transformer_d;
152  // The iterator to store into the D matrix.
153  GlobalStoreIteratorD global_store_iterator(
154  params.iterator_d, bounds, block, pointer_offset, predicate_offset);
155 
157  for (int w = 0; w < Iterations::kW; ++w) {
158  // Load the C matrix into fragment.
159  if (!kBetaIsZero_) {
160  iterator_load(global_load_iterator, fragment_c);
161  }
162 
163  // Make sure we can write to shared memory.
165 
166  // Copy the accumulators to shared memory.
167  int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
168 
169  SharedStoreTransformerD shared_store_transformer;
170  typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
171  shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
172 
173  SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
174  shared_storage.shared_stream.store);
175  shared_iterator_store(shared_store_iterator, shared_store_transformed_d);
176 
177  // Make sure the data is in shared memory.
179 
180  // Copy the accumulators back to registers from shared memory.
181  SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
182  shared_storage.shared_stream.load);
183  typename SharedLoadIteratorD::Fragment fetched_d;
184  shared_iterator_load(shared_load_iterator, fetched_d);
185 
186  // Do the math.
187  typename GlobalTransformerD::InputFragment fragment_d;
188 
189  if (kBetaIsZero_) {
190  functor.evaluate(fetched_d, fragment_d);
191  } else {
192  // Transform C fragment.
193  transformer_c.transform(fragment_c, transformed_c);
194  // Do the math.
195  functor.evaluate(fetched_d, transformed_c, fragment_d);
196  }
197 
198  // Transform D fragment.
199  typename GlobalTransformerD::OutputFragment transformed_d;
200  transformer_d.transform(fragment_d, transformed_d);
201 
202  // Copy the results to global memory.
203  iterator_store(global_store_iterator, transformed_d);
204  }
205  }
206  }
207 
209  CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }
210 
212  CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
213 
215  Params const& params;
220 };
221 
223 
224 } // namespace gemm
225 } // namespace cutlass
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue.h:98
Traits::SharedStoreIteratorD SharedStoreIteratorD
The iterator to store D in shared memory.
Definition: gemm_epilogue.h:84
Definition: convert.h:33
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from a shared memory input iterator.
Definition: iterator_access.h:75
Traits::Params Params
The params.
Definition: gemm_epilogue.h:57
Definition: gemm_epilogue.h:53
CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord< 3 > const &block, Accumulators &accumulators)
Definition: gemm_epilogue.h:117
CUTLASS_DEVICE GemmEpilogue(Params const &params_, SharedStorage &shared_storage_, Index m_, Index n_)
Ctor.
Definition: gemm_epilogue.h:101
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Definition: convert.h:69
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm_epilogue.h:59
Traits::GlobalTransformerD GlobalTransformerD
The transformer for D.
Definition: gemm_epilogue.h:80
Traits::OutputTile OutputTile
The output tile.
Definition: gemm_epilogue.h:62
Traits::Accumulators Accumulators
The accumulators.
Definition: gemm_epilogue.h:66
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:60
CUTLASS_DEVICE void shared_load_fence()
The memory fence for shared loads.
Definition: gemm_epilogue.h:209
SharedStorage & shared_storage
The shared storage.
Definition: gemm_epilogue.h:217
GemmEpilogueTraits_ Traits
The traits class.
Definition: gemm_epilogue.h:55
CUTLASS_DEVICE bool is_zero(T x)
Definition: gemm_epilogue.h:42
Params const & params
The params.
Definition: gemm_epilogue.h:215
Traits::SharedLoadIteratorD SharedLoadIteratorD
The iterator to load D in shared memory.
Definition: gemm_epilogue.h:88
Traits::Index Index
The index.
Definition: gemm_epilogue.h:93
#define static_assert(__e, __m)
Definition: platform.h:145
Traits::SharedStoreTransformerD SharedStoreTransformerD
The shared store transformer for D.
Definition: gemm_epilogue.h:86
CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment)
Stores a fragment to a shared memory output iterator.
Definition: iterator_access.h:228
Traits::GlobalStoreIteratorD GlobalStoreIteratorD
The iterator for D in global memory.
Definition: gemm_epilogue.h:82
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:48
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment)
Stores a fragment to an output iterator.
Definition: iterator_access.h:193
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue.h:96
Index n
Definition: gemm_epilogue.h:219
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from an input iterator.
Definition: iterator_access.h:41
Traits::Functor Functor
The functor in charge of the math.
Definition: gemm_epilogue.h:70
Traits::Iterations Iterations
The number of iterations.
Definition: gemm_epilogue.h:64
CUTLASS_DEVICE void epilogue(Coord< 3 > const &block, Accumulators &accumulators)
Execute the epilogue.
Definition: gemm_epilogue.h:108
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Copy< typename SharedLoadIteratorD::Fragment > SharedLoadTransformerD
The shared load transformer for D.
Definition: gemm_epilogue.h:90
Traits::Scalar Scalar
The scalar.
Definition: gemm_epilogue.h:68
Defines conversion operations among Fragments of different base type.
Index m
The dimensions of the GEMM.
Definition: gemm_epilogue.h:219
CUTLASS_DEVICE void shared_store_fence()
The memory fence for shared stores.
Definition: gemm_epilogue.h:212
Traits::GlobalTransformerC GlobalTransformerC
The transformer for C.
Definition: gemm_epilogue.h:78
Traits::GlobalLoadIteratorC GlobalLoadIteratorC
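The iterator to load the elements of the C matrix. (The static_assert just before this typedef states that 3D or 4D shapes are not supported.)
Definition: gemm_epilogue.h:76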