Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_global_tile.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/coord.h>
31 #include <cutlass/util/platform.h>
32 
34 #include <cutlass/matrix_traits.h>
36 #include <cutlass/reshape_tile.h>
37 #include <cutlass/tile_iterator.h>
38 
39 namespace cutlass {
40 namespace gemm {
41 
43 
44 // The following functor reshapes a tile of threads to match a tile of data. The idea is that when
45 // the user wants to build the iterator traits, he/she may want to specify the tile independently
46 // from the number of scalars loaded/stored per instruction. For example, in the row-major version
47 // with a tile of size 128x8 - the user may want the iterator to work with 32x8 threads if
48 // each thread loads 1 scalar per LDG. If the user changes to 4 scalars per LDG, then the tile of
49 // threads has to change. The code below detects that and corrects the code automatically - it is
50 // a helper when the user does not specify the right configuration.
51 
// Primary ReshapeThreads template: exposes the caller-supplied Threads_ shape unchanged as
// Threads. The unnamed bool parameter defaults to (Tile_::kW < Threads_::kW).
// NOTE(review): a specialization keyed on that flag appears to follow, but its header line is
// missing from this listing - confirm against the original header.
52 template <typename Tile_, typename Threads_, bool = (Tile_::kW < Threads_::kW)>
53 struct ReshapeThreads {
54  typedef Threads_ Threads;
55 };
56 
57 template <typename Tile_, typename Threads_>
59  typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1> Threads;
60 };
61 
63 
64 template <GemmOperand::Kind kOperand_,
65  MatrixLayout::Kind kLayout_,
66  typename Scalar_,
67  typename Tile_,
68  typename Threads_,
69  int kAccessSize_>
72  static GemmOperand::Kind const kOperand = kOperand_;
74  static MatrixLayout::Kind const kLayout = kLayout_;
78  typedef Scalar_* Pointer;
80  static int const kAccessSize = kAccessSize_;
83 
90 
96  typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kAccessSize>
98 
100 
102  struct ThreadOffset {
105  int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
106  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
107 
108  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
109  }
110  };
111 };
112 
114 
115 template <typename Scalar_, typename Tile_, typename Threads_, int kStrideH_, int kAccessSize_>
116 struct GemmGlobalTileCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
117  MatrixLayout::kColumnMajor,
118  Scalar_,
119  Tile_,
120  Threads_,
121  kAccessSize_> {
125  Scalar_,
126  Tile_,
127  Threads_,
128  kAccessSize_>
130 
132  static int const kStrideH = kStrideH_;
135 
136  typedef typename Base::Iterations Iterations;
137 
138  typedef typename Base::Threads Threads;
139 
141 
143 
145  struct ThreadOffset {
148  int thread_offset_h = threadIdx.x / Threads::kW * kStrideH * Iterations::kH;
149  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
150 
151  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
152  }
153  };
154 };
155 
157 
158 template <typename TileTraits_, typename Index_ = int>
160  : public TileLoadIterator<TileTraits_,
161  typename TileTraits_::Scalar,
162  TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
163  : IteratorAdvance::kW,
164  MemorySpace::kGlobal,
165  Index_> {
168 
169  typedef TileLoadIterator<TileTraits_,
170  typename TileTraits_::Scalar,
171  TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
174  Index_>
177  static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
179  typedef typename Base::Fragment Fragment;
181  typedef typename TileTraits_::Scalar Scalar;
183  typedef typename TileTraits_::Threads Threads;
185  typedef Index_ Index;
187  typedef typename TileTraits_::ThreadOffset ThreadOffset;
190 
192 
194  typedef typename Base::Params BaseParams;
195 
196  struct Params : public BaseParams {
199  Index inc_d = 0;
200  Index inc_advance = 0;
201  // Move by some columns for each iteration in the H dimension.
202  Index inc_h = Base::Delta::kH * stride_h;
203 
204  // Move by some more columns in the number of iterations if the D dimension is > 1.
205  if (Base::Delta::kD > 0) {
206  inc_d = Base::Delta::kD * stride_h - (Base::Iterations::kH - 1) * inc_h;
207  }
208 
209  // Move to the beginning of the next iteration.
210  if (kAdvance == IteratorAdvance::kH && Base::Delta::kD > 0) {
211  inc_advance = inc_d;
212  } else if (kAdvance == IteratorAdvance::kH) {
213  inc_advance = inc_h;
214  } else if (Base::Delta::kD > 0) {
215  inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
216  (Base::Iterations::kH - 1) * inc_h -
217  (Base::Iterations::kD - 1) * Base::Delta::kD * stride_h;
218  } else {
219  inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
220  (Base::Iterations::kH - 1) * inc_h;
221  }
222 
224  return 0;
225  }
226  };
227 
232 
// Builds the predicate vector with one bit per (d, h, w, c) iteration of the tile.
// bounds, block: 3-element coordinates; only elements [1] and [2] are read here.
// A bit ends up set only if both its W offset and its H (+D) offset are below the
// corresponding bound computed below.
233  CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) {
234  // Start with every predicate cleared; the loops below set the final value of each bit.
235  predicates.fill(0);
236 
// The meaning of bounds[1] / bounds[2] swaps with the advance dimension.
// NOTE(review): each branch subtracts a block coordinate from only one of the two bounds
// (bounds_w by block[2] when advancing in H, bounds_h by block[1] otherwise) - confirm intended.
237  int bounds_h, bounds_w;
238  if (kAdvance == IteratorAdvance::kH) {
239  bounds_w = bounds[2] - block[2];
240  bounds_h = bounds[1];
241 
242  } else {
243  bounds_w = bounds[1];
244  bounds_h = bounds[2] - block[1];
245  }
246 
247  // Fill in the bits of the predicate vector, one per (d, h, w, c) iteration.
248  for (int d = 0; d < Base::Iterations::kD; ++d) {
249  for (int h = 0; h < Base::Iterations::kH; ++h) {
250  for (int w = 0; w < Base::Iterations::kW; ++w) {
251  for (int c = 0; c < Base::Iterations::kC; ++c) {
// The flag does not depend on c: every c iteration of a given (d, h, w) gets the same value.
252  bool flag = w * Base::Delta::kW < bounds_w;
// The D-dimension offset contributes to the H test only when advancing in H.
253  if (kAdvance == IteratorAdvance::kH) {
254  flag = flag && (h * Base::Delta::kH + d * Base::Delta::kD) < bounds_h;
255  } else {
256  flag = flag && (h * Base::Delta::kH) < bounds_h;
257  }
// Linear bit index of this iteration within the Iterations shape.
258  int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
259  predicates.set(bit, flag);
260  }
261  }
262  }
263  }
264  }
265 
// Constructor. Copies _params, offsets the copied pointer to this thread's first element and
// initializes the predicate vector.
// _params: iterator parameters (copied into the params member).
// bounds: forwarded unchanged to initialize_predicates().
// block: block coordinates; elements [1] and [2] are added to the thread offset.
// thread_offset_func: functor returning this thread's offset; defaults to ThreadOffset().
267  CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
268  const Coord<3>& bounds,
269  const Coord<3>& block,
270  ThreadOffset thread_offset_func = ThreadOffset())
271  : params(_params) {
272  thread_offset = thread_offset_func();
273  // The column (H component of the thread offset).
274  Index block_h = thread_offset[1];
275  // The contiguous dimension (W component of the thread offset).
276  Index block_w = thread_offset[2];
277 
278  // Add the block indices; which block coordinate feeds H vs. W depends on the advance dimension.
279  if (kAdvance == IteratorAdvance::kH) {
280  block_h += block[1];
281  block_w += block[2];
282 
283  } else {
284  block_h += block[2];
285  block_w += block[1];
286  }
287 
288  // Set up the pointer: H offsets are scaled by stride_h, W offsets are added as-is.
289  params.pointer += (block_h * params.stride_h + block_w);
290 
291  // Initialize predicates using the combined (thread + block) offsets as the block coordinate.
292  initialize_predicates(bounds, make_Coord(0, block_h, block_w));
293  }
294 
// Increment the pointer in the H dimension by the precomputed params.inc_h.
296  CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
// Increment the pointer in the D dimension by the precomputed params.inc_d.
298  CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
// Increment the pointer to move to the next iteration, using params.inc_advance.
300  CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
301 
// Returns the current pointer (params.pointer), including any offsets and increments applied so far.
304  Scalar const* data() const { return params.pointer; }
305 
// Residue handling: clears the predicate of every iteration whose offset along the advance
// dimension is at or beyond k. Predicates are only ever cleared here, never set.
// k: exclusive upper limit for the offset along the advance dimension.
307  CUTLASS_DEVICE void residue(Index k) {
308  // The coordinates of the thread (H component of the stored thread offset).
309  Index block_h = thread_offset[1];
310  // The contiguous dimension (W component of the stored thread offset).
311  Index block_w = thread_offset[2];
312 
313  // Update the predicate vector, visiting every (d, h, w, c) iteration.
314  for (int d = 0; d < Base::Iterations::kD; ++d) {
315  for (int h = 0; h < Base::Iterations::kH; ++h) {
316  for (int w = 0; w < Base::Iterations::kW; ++w) {
317  for (int c = 0; c < Base::Iterations::kC; ++c) {
// Offset of this access along the advance dimension: thread offset plus iteration offset.
318  Index offset = 0;
319  if (kAdvance == IteratorAdvance::kH) {
320  offset += block_h + h * Base::Delta::kH + d * Base::Delta::kD;
321  } else {
322  offset += block_w + w * Base::Delta::kW;
323  }
324 
325  int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
// Disable the access if it falls at or past k; otherwise leave the bit as it was.
326  if (offset >= k) {
327  predicates.set(bit, false);
328  }
329  }
330  }
331  }
332  }
333  }
334 
// Is the access at iteration (d, h, w, c) valid? Returns the predicate bit stored for it.
336  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
// Same linear bit index as used when the predicates were written.
337  int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
338  return predicates[bit];
339  }
340 
343 };
344 
346 
347 template <typename TileTraits_, typename Index_ = int>
348 struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
349  typename TileTraits_::Scalar,
350  IteratorAdvance::kH,
351  MemorySpace::kGlobal,
352  Index_> {
356  typedef TileIteratorBase<TileTraits_,
357  typename TileTraits_::Scalar,
360  Index_>
362 
364  static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
365 
367  typedef typename TileTraits_::Scalar Scalar;
369  typedef typename TileTraits_::Pointer Pointer;
371  typedef typename TileTraits_::Threads Threads;
373  typedef Index_ Index;
375  typedef typename TileTraits_::ThreadOffset ThreadOffset;
376 
378  struct Params {
389 
392  Pointer pointer, Index ld, Index bound, Index epilogue_stride_w, Index epilogue_delta_w) {
393  // The pointer.
394  this->pointer = pointer;
395  // Each column of the matrix.
396  stride_h = TileTraits_::ThreadsDelta::kH * ld;
397  // Each thread outputs 1 column per iteration. The stride between columns is given by the
398  // number of scalars that are loaded per LDS for B.
399  inc_h = ld * TileTraits_::kStrideH;
400  inc_advance =
401  (ld - ld * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
402 
403  predicate_offset = bound;
404  predicate_inc_h = TileTraits_::kStrideH;
406  -((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
407 
408  return 0;
409  }
410  };
411 
415 
// Default constructor; its body is empty, so it performs no setup of the pointer or predicates.
417  CUTLASS_DEVICE GemmGlobalIteratorCd() {}
418 
// Constructor. Copies params, offsets the copied pointer to this thread's first element, builds
// the per-W predicates and adjusts the predicate offset.
// params: iterator parameters (copied into the member of the same name).
// bounds: only bounds[2] is read, as the exclusive limit in the W dimension.
// block: block coordinates; elements [1] and [2] are added to the thread offset.
// offset: extra element offset added to the pointer.
// pred_offset: extra amount subtracted from predicate_offset.
// thread_offset_func: functor returning this thread's offset; defaults to ThreadOffset().
420  CUTLASS_DEVICE GemmGlobalIteratorCd(Params const& params,
421  const Coord<3>& bounds,
422  const Coord<3>& block,
423  int offset = 0,
424  int pred_offset = 0,
425  ThreadOffset thread_offset_func = ThreadOffset())
426  : params(params) {
427  thread_offset = thread_offset_func();
428  // Each warp works on a different column of the tile.
429  int const h = thread_offset[1] + block[1];
430  // Each lane writes a different element.
431  int const w = thread_offset[2] + block[2];
432  // Set up the pointer. The parameter shadows the member, hence this-> for the write.
433  this->params.pointer += ((h * params.stride_h + w) + offset);
434 
435  // Prepare the vector of predicates: bit i is set when the i-th W access is below bounds[2].
436  for (int i = 0; i < Base::Iterations::kW; ++i) {
437  predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
438  }
// Reduce the predicate offset by this thread's H position; valid() requires it to stay > 0.
439  this->params.predicate_offset -= (h + pred_offset);
440  }
441 
// Increment the pointer in the C dimension. The body is empty, so this is a no-op.
443  CUTLASS_DEVICE void inc_c() {}
// Increment the pointer in the W dimension. The body is empty, so this is a no-op.
445  CUTLASS_DEVICE void inc_w() {}
// Increment the pointer in the H dimension.
// NOTE(review): the body lines (original lines 448-449) are missing from this listing.
447  CUTLASS_DEVICE void inc_h() {
450  }
// Increment the pointer in the D dimension. The body is empty, so this is a no-op.
452  CUTLASS_DEVICE void inc_d() {}
// Increment the pointer to move to the next iteration.
// NOTE(review): the body lines (original lines 455-456) are missing from this listing.
454  CUTLASS_DEVICE void inc_advance() {
457  }
458 
// Test the validity of the iterator for an access. Only w is consulted: the access is valid when
// the W predicate is set and the predicate offset is still positive. d, h and c are unused.
460  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
461  return predicates.at(w) && params.predicate_offset > 0;
462  }
463 
// Returns the raw pointer (params.pointer).
466  Pointer data() { return params.pointer; }
467 
// Const overload; returns the same params.pointer value.
469  Pointer const data() const { return params.pointer; }
470 
473 };
474 
476 
477 } // namespace gemm
478 } // namespace cutlass
Definition: gemm_global_tile.h:116
Shape< 0, Threads::kH, Threads::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:92
Index inc_advance
The strides to increment the pointer.
Definition: gemm_global_tile.h:384
CUTLASS_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:452
Definition: convert.h:33
cutlass::PredicateVector< ShapeCount< typename Base::Iterations >::kCount > PredicateVector
Definition: gemm_global_tile.h:191
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:177
T type
Definition: platform.h:369
Base::Params BaseParams
Iterator parameters type.
Definition: gemm_global_tile.h:194
Shape< 1, Tile::kH/Threads::kH, Tile::kW/Threads::kW, Tile::kC/kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_global_tile.h:97
Index_ Index
The index.
Definition: gemm_global_tile.h:373
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
GemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:354
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:364
Definition: gemm_global_tile.h:70
Scalar_ * Pointer
The pointer.
Definition: gemm_global_tile.h:78
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:62
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:356
Definition: load_store.h:43
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:241
GemmMultiplicandTraits< Tile, kOperand, kLayout > MultiplicandTraits
Definition: gemm_global_tile.h:99
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_global_tile.h:82
TileIteratorBase< TileTraits_, typename TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:361
Shape< 1, 1, Tile::kC > ThreadsDelta
The relative offset between two elements in the H/W dimension in adjacent threads.
Definition: gemm_global_tile.h:89
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:134
Index predicate_inc_h
Definition: gemm_global_tile.h:386
static CUTLASS_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:166
CUTLASS_HOST_DEVICE Pointer const data() const
Definition: gemm_global_tile.h:469
CUTLASS_DEVICE void initialize_predicates(const Coord< 3 > &bounds, const Coord< 3 > &block)
Definition: gemm_global_tile.h:233
Definition: tile_iterator.h:62
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:331
TileLoadIterator< TileTraits_, typename TileTraits_::Scalar, TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:175
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: gemm_global_tile.h:336
Definition: gemm_global_tile.h:196
Definition: matrix_traits.h:43
C++ features that may be otherwise unimplemented for CUDA device functions.
Definition: gemm_global_tile.h:159
CUTLASS_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:454
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: gemm_global_tile.h:129
Kind
Definition: load_store.h:40
Index stride_h
Definition: tile_iterator.h:172
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: gemm_global_tile.h:189
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:183
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index stride_h)
Initializes params to load a strip-mined tile, given pointer and stride_h.
Definition: gemm_global_tile.h:198
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:425
static int const kStrideH
The stride in the H dimension.
Definition: gemm_global_tile.h:132
static int const kH
The height of the cube.
Definition: shape.h:68
Shape< Threads_::kD, Threads_::kH *Threads_::kW/Tile_::kW, Tile_::kW, 1 > Threads
Definition: gemm_global_tile.h:59
Index predicate_inc_advance
The strides to increment the predicate offset.
Definition: gemm_global_tile.h:386
static GemmOperand::Kind const kOperand
Identity of the operand.
Definition: gemm_global_tile.h:72
Index inc_h
Definition: tile_iterator.h:176
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:302
PredicateVector predicates
The predicates.
Definition: gemm_global_tile.h:342
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_global_tile.h:76
CUTLASS_HOST_DEVICE Scalar const * data() const
Returns the current pointer.
Definition: gemm_global_tile.h:304
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Base::Fragment Fragment
Fragment type loaded by the iterator.
Definition: gemm_global_tile.h:179
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:371
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:147
CUTLASS_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:447
CUTLASS_DEVICE GemmGlobalIteratorCd(Params const &params, const Coord< 3 > &bounds, const Coord< 3 > &block, int offset=0, int pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:420
Definition: gemm_operand.h:67
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:102
Index inc_advance
Definition: tile_iterator.h:179
CUTLASS_DEVICE void residue(Index k)
That's the residue! Update the predicates.
Definition: gemm_global_tile.h:307
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:343
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const &_params, const Coord< 3 > &bounds, const Coord< 3 > &block, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:267
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld, Index bound, Index epilogue_stride_w, Index epilogue_delta_w)
Setup the params.
Definition: gemm_global_tile.h:391
CUTLASS_DEVICE void inc_c()
Increment the pointer in the C dimension.
Definition: gemm_global_tile.h:443
CUTLASS_HOST_DEVICE Pointer data()
Returns the raw pointer.
Definition: gemm_global_tile.h:466
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:390
Base::Threads Threads
Definition: gemm_global_tile.h:138
Index stride_h
The stride in the H dimension to setup the thread in the block.
Definition: gemm_global_tile.h:382
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:104
Shape< 0, 0, Threads::kW *ThreadsDelta::kW, kAccessSize > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: gemm_global_tile.h:94
Statically sized array of bits implementing.
Definition: predicate_vector.h:104
CUTLASS_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:296
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:375
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Base::ImmediateOffsetStrides ImmediateOffsetStrides
Definition: gemm_global_tile.h:142
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:367
Index inc_h
Definition: gemm_global_tile.h:384
cutlass::PredicateVector< Base::Iterations::kW > predicates
The predicates for the row.
Definition: gemm_global_tile.h:472
CUTLASS_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:298
Pointer pointer
The pointer.
Definition: gemm_global_tile.h:380
GemmGlobalIteratorAb< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:167
ReshapeTile< Tile_, kAccessSize_ >::Tile Tile
The tile shape.
Definition: gemm_global_tile.h:85
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:364
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:102
CUTLASS_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: gemm_global_tile.h:445
Params params
Definition: gemm_global_tile.h:412
Definition: gemm_global_tile.h:348
Definition: matrix_traits.h:36
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:414
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:187
static int const kW
The width of the cube.
Definition: shape.h:70
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:364
Parameters.
Definition: tile_iterator.h:388
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:145
Kind
Definition: matrix_traits.h:36
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_global_tile.h:80
Tile_ Tile
Definition: reshape_tile.h:43
Definition: tile_iterator.h:62
Base::Iterations Iterations
Definition: gemm_global_tile.h:136
Index_ Index
The index.
Definition: gemm_global_tile.h:185
TileTraits_::Pointer Pointer
The pointer.
Definition: gemm_global_tile.h:369
Kind
Definition: matrix_traits.h:43
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:181
Threads_ Threads
Definition: gemm_global_tile.h:54
ReshapeThreads< Tile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:87
CUTLASS_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:300
CUTLASS_DEVICE GemmGlobalIteratorCd()
Ctor.
Definition: gemm_global_tile.h:417
Params params
The parameters.
Definition: gemm_global_tile.h:231
Defines properties of matrices used to denote layout and operands to GEMM kernels.
The params.
Definition: gemm_global_tile.h:378
Base::ThreadsDelta ThreadsDelta
Definition: gemm_global_tile.h:140
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Test the validity of the iterator.
Definition: gemm_global_tile.h:460
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:229
Computes derived counts of a Layout Concept based class.
Definition: shape.h:79
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: gemm_global_tile.h:388
Index inc_d
Definition: tile_iterator.h:175
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:74