Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_global_tile.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template <typename Scalar_, typename Tile_, typename Threads_, int kAccessSize_>  // Global tile traits for the C/D operand: GemmGlobalTileTraits with GemmOperand::kC, column-major.
38 struct WmmaGemmGlobalIteratorCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
39  MatrixLayout::kColumnMajor,
40  Scalar_,
41  Tile_,
42  Threads_,
43  kAccessSize_> {
47  Scalar_,  // NOTE: lines 44-46 and 51 are elided in this listing. Per the Base tooltip they complete
48  Tile_,  // the base-class alias Base = GemmGlobalTileTraits<GemmOperand::kC, MatrixLayout::kColumnMajor,
49  Threads_,  // Scalar_, Tile_, Threads_, kAccessSize_> (defined at line 51).
50  kAccessSize_>
52 
55  // Elided line 54 overrides Delta = Shape<0, 0, Base::Delta::kW, Base::Delta::kC> (strides between loads/stores).
57  struct ThreadOffset {  // Computes the thread offset in (H, W) based on thread ID.
59  Coord<4> operator()() const {  // Tooltip shows CUTLASS_HOST_DEVICE; that qualifier is not visible in this listing.
60  int thread_offset_h = threadIdx.x / Base::Threads::kW;  // H: which group of Threads::kW threads this thread belongs to.
61  int thread_offset_w = threadIdx.x % Base::Threads::kW * Base::ThreadsDelta::kW;  // W: (tid % Threads::kW) * ThreadsDelta::kW, since % and * group left to right.
62 
63  return make_Coord(0, thread_offset_h, thread_offset_w, 0);  // Elements 1 and 2 carry h and w; elements 0 and 3 are 0.
64  }
65  };
66 };
67 
69 
70 template <typename TileTraits_, typename Index_ = int>  // Global-memory (MemorySpace::kGlobal) iterator over the C/D tile; advances along H (IteratorAdvance::kH).
71 struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
72  typename TileTraits_::Scalar,
73  IteratorAdvance::kH,
74  MemorySpace::kGlobal,
75  Index_> {
79  typedef TileTraits_ Traits;  // The traits. (Elided line 77 declares This_ = WmmaGemmGlobalIteratorCd<TileTraits_, Index_>.)
81  typedef TileIteratorBase<Traits,  // The base class. Lines 83-84 and 86 are elided; per the Base tooltip the full alias is
82  typename TileTraits_::Scalar,  // TileIteratorBase<Traits, typename TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_> Base.
85  Index_>  // Elided line 88: ImmediateOffsetStrides = Shape<0, 0, Base::Delta::kW, Base::Delta::kC>.
90  static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;  // The layout, forwarded from the traits.
91 
93  typedef typename TileTraits_::Scalar Scalar;  // The scalar.
95  typedef typename TileTraits_::Pointer Pointer;  // The pointer.
97  typedef typename TileTraits_::Threads Threads;  // The threads.
99  typedef Index_ Index;  // The index.
101  typedef typename TileTraits_::ThreadOffset ThreadOffset;  // The thread offset functor.
102 
104  struct Params {  // The params. Members on elided lines 105-114: pointer, stride_h, inc_h, inc_advance,
115  // predicate_offset, predicate_inc_h, predicate_inc_advance. Elided line 117 opens: CUTLASS_HOST_DEVICE int initialize(
118  Pointer pointer, Index ld, Index n, Index epilogue_stride_w, Index epilogue_delta_w) {
119  // The pointer.
120  this->pointer = pointer;
121  // Set up the base stride. One "group of threads" per column.
122  stride_h = ld;
123  // Each thread outputs 1 column per iteration.
124  inc_h = ld * TileTraits_::Threads::kH;  // Threads::kH strides of ld (presumably the per-H-step pointer increment; inc_h()'s body is elided).
125  inc_advance = inc_h + epilogue_stride_w;  // One H step plus the epilogue W stride.
126 
127  predicate_offset = n;  // Starts at n; the ctor subtracts h + pred_offset and valid() requires it to stay > 0.
128  predicate_inc_h = TileTraits_::Threads::kH;  // Threads::kH (presumably the per-H-step predicate-offset increment).
129  predicate_inc_advance = predicate_inc_h + epilogue_delta_w;  // One H step plus epilogue_delta_w.
130 
131  // Always succeeds and returns 0.
132  return 0;
133  }
134  };
135 
137  // Elided line 136 declares the member: Params params;
139  // Elided line 138 declares the member: Coord<4> thread_offset;
141  CUTLASS_DEVICE WmmaGemmGlobalIteratorCd() {}  // Default ctor; the empty body performs no setup.
142 
144  CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const& params,  // Ctor: places this thread at block + thread_offset_func()
145  const Coord<3>& bounds,  // and builds the W predicates against bounds[2].
146  const Coord<3>& block,
147  int const pointer_offset = 0,
148  int const pred_offset = 0,
149  ThreadOffset thread_offset_func = ThreadOffset())
150 
151  : params(params) {
152  thread_offset = thread_offset_func();
153  // Each warp works on a different column of the tile (presumes Threads::kW equals the warp size -- TODO confirm).
154  int const h = thread_offset[1] + block[1];
155  // Each lane writes a different element.
156  int const w = thread_offset[2] + block[2];
157  // Set up the pointer: advance by h * stride_h + w + pointer_offset elements.
158  this->params.pointer += ((h * params.stride_h + w) + pointer_offset);
159 
160  // Prepare the vector of predicates: predicate i is set when w + i * Delta::kW < bounds[2].
161  for (int i = 0; i < Base::Iterations::kW; ++i) {
162  predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
163  }
164  this->params.predicate_offset -= (h + pred_offset);  // Remove this thread's H position (plus pred_offset) from the remaining extent.
165  }
166 
168  CUTLASS_DEVICE void inc_c() {}  // Increment the pointer in the C dimension (no-op).
170  CUTLASS_DEVICE void inc_w() {}  // Increment the pointer in the W dimension (no-op).
172  CUTLASS_DEVICE void inc_h() {  // Increment the pointer in the H dimension (body lines 173-174 are elided here).
175  }
177  CUTLASS_DEVICE void inc_d() {}  // Increment the pointer in the D dimension (no-op).
179  CUTLASS_DEVICE void inc_advance() {  // Move to the next iteration (body lines 180-181 are elided here).
182  }
183 
185  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {  // Test the predicate; d, h and c are unused.
186  return predicates.at(w) && params.predicate_offset > 0;  // W predicate w is set and predicate_offset is still positive.
187  }
188 
191  Pointer data() { return params.pointer; }  // Returns the raw pointer.
192 
194  Pointer const data() const { return params.pointer; }  // Returns the raw pointer (const overload).
195  // Elided line 197 declares: cutlass::PredicateVector<Base::Iterations::kW> predicates (the predicates for the row).
198 };
199 
201 
202 } // namespace gemm
203 } // namespace cutlass
TileTraits_::Threads Threads
The threads.
Definition: wmma_gemm_global_tile.h:97
Definition: convert.h:33
Defines iterators for efficiently loading and storing to global memory.
Definition: gemm_global_tile.h:70
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:356
CUTLASS_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: wmma_gemm_global_tile.h:177
Definition: load_store.h:43
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Index stride_h
The stride in the H dimension, used to set up the thread's position in the block.
Definition: wmma_gemm_global_tile.h:108
CUTLASS_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: wmma_gemm_global_tile.h:170
Index_ Index
The index.
Definition: wmma_gemm_global_tile.h:99
TileTraits_::Scalar Scalar
The scalar.
Definition: wmma_gemm_global_tile.h:93
Definition: tile_iterator.h:62
Definition: matrix_traits.h:43
Params params
Definition: wmma_gemm_global_tile.h:136
Index predicate_inc_h
The strides to increment the predicate offset.
Definition: wmma_gemm_global_tile.h:114
Pointer pointer
The pointer.
Definition: wmma_gemm_global_tile.h:106
CUTLASS_HOST_DEVICE Pointer const data() const
Returns the raw pointer.
Definition: wmma_gemm_global_tile.h:194
CUTLASS_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: wmma_gemm_global_tile.h:179
The params.
Definition: wmma_gemm_global_tile.h:104
Index inc_h
The strides to increment the pointer.
Definition: wmma_gemm_global_tile.h:110
TileIteratorBase< Traits, typename TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: wmma_gemm_global_tile.h:86
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd()
Ctor.
Definition: wmma_gemm_global_tile.h:141
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: wmma_gemm_global_tile.h:112
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: wmma_gemm_global_tile.h:59
Definition: wmma_gemm_global_tile.h:71
Index predicate_inc_advance
The strides to increment the predicate offset.
Definition: wmma_gemm_global_tile.h:114
TileTraits_::Pointer Pointer
The pointer.
Definition: wmma_gemm_global_tile.h:95
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: wmma_gemm_global_tile.h:54
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Coord< 4 > thread_offset
Definition: wmma_gemm_global_tile.h:138
Index inc_advance
The strides to increment the pointer.
Definition: wmma_gemm_global_tile.h:110
static MatrixLayout::Kind const kLayout
The layout.
Definition: wmma_gemm_global_tile.h:90
Definition: wmma_gemm_global_tile.h:38
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:102
TileTraits_::ThreadOffset ThreadOffset
The thread offset functor.
Definition: wmma_gemm_global_tile.h:101
Definition: matrix_traits.h:36
CUTLASS_HOST_DEVICE Pointer data()
Returns the raw pointer.
Definition: wmma_gemm_global_tile.h:191
static int const kW
The width of the cube.
Definition: shape.h:70
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:364
Kind
Definition: matrix_traits.h:36
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: wmma_gemm_global_tile.h:51
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const &params, const Coord< 3 > &bounds, const Coord< 3 > &block, int const pointer_offset=0, int const pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: wmma_gemm_global_tile.h:144
WmmaGemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: wmma_gemm_global_tile.h:77
cutlass::PredicateVector< Base::Iterations::kW > predicates
The predicates for the row.
Definition: wmma_gemm_global_tile.h:197
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > ImmediateOffsetStrides
Override the strides in each dimension between different loads/stores.
Definition: wmma_gemm_global_tile.h:88
Computes the thread offset in (H, W) based on thread ID.
Definition: wmma_gemm_global_tile.h:57
CUTLASS_DEVICE void inc_c()
Increment the pointer in the C dimension.
Definition: wmma_gemm_global_tile.h:168
CUTLASS_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: wmma_gemm_global_tile.h:172
TileTraits_ Traits
The traits.
Definition: wmma_gemm_global_tile.h:79
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Test the predicate.
Definition: wmma_gemm_global_tile.h:185
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld, Index n, Index epilogue_stride_w, Index epilogue_delta_w)
Set up the params.
Definition: wmma_gemm_global_tile.h:117