Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_global_tile.h
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/coord.h"
#include "cutlass/util/platform.h"

#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

// The following functor reshapes a tile of threads to match a tile of data. The idea is that when
// the user builds the iterator traits, they may want to specify the tile independently of the
// number of scalars loaded/stored per instruction. For example, in the row-major version with a
// tile of size 128x8, the user may want the iterator to work with 32x8 threads if each thread
// loads 1 scalar per LDG. If the user switches to 4 scalars per LDG, the tile of threads has to
// change. The code below detects that case and corrects the thread arrangement automatically - it
// is a helper for when the user does not specify the right configuration.

template <typename Tile_, typename Threads_, bool = (Tile_::kW < Threads_::kW)>
struct ReshapeThreads {
  typedef Threads_ Threads;
};

template <typename Tile_, typename Threads_>
struct ReshapeThreads<Tile_, Threads_, true> {
  typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1> Threads;
};

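// Illustrative only (not part of the original file): with a vectorized tile of
// Shape<1, 8, 32> and a requested thread arrangement of Shape<1, 4, 64>, the tile is
// narrower than the thread arrangement (kW: 32 < 64), so the specialization above folds
// the excess threads into the H dimension:
//
//   ReshapeThreads<Shape<1, 8, 32>, Shape<1, 4, 64> >::Threads
//       == Shape<1, 4 * 64 / 32, 32, 1> == Shape<1, 8, 32, 1>
//
// A tile at least as wide as the thread arrangement is left untouched by the primary template.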
////////////////////////////////////////////////////////////////////////////////////////////////////

template <GemmOperand::Kind kOperand_,
          MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Threads_,
          int kAccessSize_>
struct GemmGlobalTileTraits {
  /// Identity of the operand.
  static GemmOperand::Kind const kOperand = kOperand_;
  /// The layout.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar.
  typedef typename platform::remove_const<Scalar_>::type Scalar;
  /// The pointer.
  typedef Scalar_* Pointer;
  /// The number of scalars per LDG/STG.
  static int const kAccessSize = kAccessSize_;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGlobal;

  /// The tile shape.
  typedef Tile_ Tile;
  /// The vectorized tile shape.
  typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile VectorizedTile;
  /// The threads shape.
  typedef typename ReshapeThreads<VectorizedTile, Threads_>::Threads Threads;
  /// The relative offset between two elements in the H/W dimension in adjacent threads.
  typedef Shape<1, 1, VectorizedTile::kC> ThreadsDelta;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, Threads::kH, Threads::kW * kAccessSize> Delta;

  /// Strides for immediate offset computation.
  typedef Shape<0, 0, Threads::kW * ThreadsDelta::kW, kAccessSize> ImmediateOffsetStrides;
  /// The number of iterations needed to load/store the tile.
  typedef Shape<1,
                VectorizedTile::kH / Threads::kH,
                VectorizedTile::kW / Threads::kW,
                VectorizedTile::kC / kAccessSize>
      Iterations;

  typedef GemmMultiplicandTraits<Tile, kOperand, kLayout> MultiplicandTraits;

  /// Computes the thread offset in (H, W) based on thread ID.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
      int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;

      return make_Coord(0, thread_offset_h, thread_offset_w, 0);
    }
  };
};
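
// Illustrative only (not part of the original file): a hypothetical instantiation for the
// A operand, assuming ReshapeTile splits the W dimension into (W / kAccessSize) x kAccessSize:
//
//   typedef GemmGlobalTileTraits<GemmOperand::kA, MatrixLayout::kColumnMajor,
//                                float const, Shape<1, 8, 128>, Shape<1, 8, 32>, 4> TraitsA;
//
//   TraitsA::VectorizedTile == Shape<1, 8, 32, 4>  // 128 scalars become 32 vectors of 4
//   TraitsA::Threads        == Shape<1, 8, 32>     // unchanged: tile kW (32) is not < threads kW (32)
//   TraitsA::ThreadsDelta   == Shape<1, 1, 4>
//   TraitsA::Iterations     == Shape<1, 1, 1, 1>   // one 4-scalar LDG covers each thread's share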

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Scalar_, typename Tile_, typename Threads_, int kStrideH_, int kAccessSize_>
struct GemmGlobalTileCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
                                                            MatrixLayout::kColumnMajor,
                                                            Scalar_,
                                                            Tile_,
                                                            Threads_,
                                                            kAccessSize_> {
  /// The base class.
  typedef GemmGlobalTileTraits<GemmOperand::kC,
                               MatrixLayout::kColumnMajor,
                               Scalar_,
                               Tile_,
                               Threads_,
                               kAccessSize_>
      Base;

  /// The stride in the H dimension.
  static int const kStrideH = kStrideH_;
  /// Override the strides in each dimension between different loads/stores.
  typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;

  typedef typename Base::Iterations Iterations;

  typedef typename Base::Threads Threads;

  typedef typename Base::ThreadsDelta ThreadsDelta;

  typedef typename Base::ImmediateOffsetStrides ImmediateOffsetStrides;

  /// Computes the thread offset in (H, W) based on thread ID.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      int thread_offset_h = threadIdx.x / Threads::kW * kStrideH * Iterations::kH;
      int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;

      return make_Coord(0, thread_offset_h, thread_offset_w, 0);
    }
  };
};
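
// Illustrative only (not part of the original file): suppose Threads == Shape<1, 8, 32>,
// kStrideH == 2, Iterations::kH == 4 and kAccessSize == 1 (so ThreadsDelta::kW == 1).
// Then thread 70 (threadIdx.x == 70) computes
//
//   thread_offset_h = 70 / 32 * 2 * 4 = 16
//   thread_offset_w = 70 % 32 * 1     =  6
//
// i.e. each row of 32 threads owns a band of kStrideH * Iterations::kH rows of the C tile.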

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TileTraits_, typename Index_ = int>
struct GemmGlobalIteratorAb
    : public TileLoadIterator<TileTraits_,
                              typename TileTraits_::Scalar,
                              TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
                                                                         : IteratorAdvance::kW,
                              MemorySpace::kGlobal,
                              Index_> {
  /// This class.
  typedef GemmGlobalIteratorAb<TileTraits_, Index_> This_;
  /// The base class.
  typedef TileLoadIterator<TileTraits_,
                           typename TileTraits_::Scalar,
                           TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
                                                                      : IteratorAdvance::kW,
                           MemorySpace::kGlobal,
                           Index_>
      Base;
  /// The layout.
  static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
  /// The tile.
  typedef typename TileTraits_::Tile Tile;
  /// Fragment type loaded by the iterator.
  typedef typename Base::Fragment Fragment;
  /// The scalar.
  typedef typename TileTraits_::Scalar Scalar;
  /// The threads.
  typedef typename TileTraits_::Threads Threads;
  /// The index.
  typedef Index_ Index;
  /// The thread offset functor.
  typedef typename TileTraits_::ThreadOffset ThreadOffset;
  /// Specifies in which dimension post-increment accesses advance.
  static IteratorAdvance::Kind const kAdvance = Base::kAdvance;

  /// The predicate vector, one bit per access.
  typedef cutlass::PredicateVector<ShapeCount<typename Base::Iterations>::kCount> PredicateVector;

  /// Iterator parameters type.
  typedef typename Base::Params BaseParams;

  struct Params : public BaseParams {
    /// Initializes params to load a strip-mined tile, given pointer and stride_h.
    CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr, long long stride_d, Index stride_h) {
      Index inc_d = 0;
      Index inc_advance = 0;
      // Move by some columns for each iteration in the H dimension.
      Index inc_h = Base::Delta::kH * stride_h;

      // Move by some more columns if the iterator also steps in the D dimension.
      if (Base::Delta::kD > 0) {
        inc_d = Base::Delta::kD * stride_h - (Base::Iterations::kH - 1) * inc_h;
      }

      // Move to the beginning of the next iteration.
      if (kAdvance == IteratorAdvance::kH && Base::Delta::kD > 0) {
        inc_advance = inc_d;
      } else if (kAdvance == IteratorAdvance::kH) {
        inc_advance = inc_h;
      } else if (Base::Delta::kD > 0) {
        inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
                      (Base::Iterations::kH - 1) * inc_h -
                      (Base::Iterations::kD - 1) * Base::Delta::kD * stride_h;
      } else {
        inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
                      (Base::Iterations::kH - 1) * inc_h;
      }

      Base::Params::initialize(ptr, stride_d, stride_h, 1, inc_d, inc_h, 0, inc_advance);
      return 0;
    }
  };
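
  // Illustrative only (not part of the original file): with kAdvance == IteratorAdvance::kH,
  // Base::Delta::kH == 8, Base::Delta::kD == 0 and leading-dimension stride stride_h, the
  // increments computed above are
  //
  //   inc_h       = 8 * stride_h   // step between two H iterations inside the tile
  //   inc_d       = 0
  //   inc_advance = inc_h          // moving to the next tile is simply one more H step
  //
  // so after (Iterations::kH - 1) in-tile steps plus inc_advance, the pointer has moved
  // Iterations::kH * 8 * stride_h scalars, i.e. to the same lane of the next tile along K.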

  /// Offset of an individual lane from the start of the tile.
  Coord<4> thread_offset;
  /// The parameters.
  Params params;
  /// The predicates.
  PredicateVector predicates;

  /// Sets up the predicates to guard loads against the tile bounds.
  CUTLASS_HOST_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block_offset) {
    // Setup the masks to control loads.
    predicates.fill(0);

    // Fill in the bits of the predicate vector.
    for (int d = 0; d < Base::Iterations::kD; ++d) {
      for (int h = 0; h < Base::Iterations::kH; ++h) {
        for (int w = 0; w < Base::Iterations::kW; ++w) {
          for (int c = 0; c < Base::Iterations::kC; ++c) {
            bool flag = w * Base::Delta::kW + thread_offset[2] + block_offset[2] < bounds[2];
            if (kAdvance == IteratorAdvance::kH) {
              flag =
                  flag &&
                  (h * Base::Delta::kH + d * Base::Delta::kD) + thread_offset[1] + block_offset[1] <
                      bounds[1];
            } else {
              flag = flag && (h * Base::Delta::kH) + thread_offset[1] + block_offset[1] < bounds[1];
            }
            int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
            predicates.set(bit, flag);
          }
        }
      }
    }
  }

  /// Ctor.
  CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const& _params,
                                           const Coord<3>& bounds,
                                           const Coord<3>& threadblock_offset,
                                           ThreadOffset thread_offset_func = ThreadOffset())
      : params(_params) {
    thread_offset = thread_offset_func();
    // Setup the pointer.
    params.pointer += ((threadblock_offset[1] + thread_offset[1]) * params.stride_h +
                       (threadblock_offset[2] + thread_offset[2]));
  }

  /// Increment the pointer in the W dimension.
  CUTLASS_HOST_DEVICE void inc_w() { Base::inc_w(); }
  /// Increment the pointer in the H dimension.
  CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
  /// Increment the pointer in the D dimension.
  CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
  /// Increment the pointer to move to the next iteration.
  CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }

  /// Loads a single fragment element from memory.
  CUTLASS_HOST_DEVICE void load_element(
      typename Base::AccessType& value, int d, int h, int w, int c) const {
    int const offset =
        ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
    Load<Scalar,
         Base::kAccessSize,
         Base::kMemorySpace,
         Base::kFragmentElementType,
         typename Base::FragmentElement,
         Base::Tile::kW,
         Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
  }

  /// That's the residue! Update the predicates for the final partial tile in K.
  CUTLASS_HOST_DEVICE void residue(Index k) {
    // The coordinates of the thread.
    Index block_h = thread_offset[1];
    // The contiguous dimension.
    Index block_w = thread_offset[2];

    // Update the predicate vector.
    for (int d = 0; d < Base::Iterations::kD; ++d) {
      for (int h = 0; h < Base::Iterations::kH; ++h) {
        for (int w = 0; w < Base::Iterations::kW; ++w) {
          for (int c = 0; c < Base::Iterations::kC; ++c) {
            Index offset = 0;
            if (kAdvance == IteratorAdvance::kH) {
              offset += block_h + h * Base::Delta::kH + d * Base::Delta::kD;
            } else {
              offset += block_w + w * Base::Delta::kW;
            }

            int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
            if (offset >= k) {
              predicates.set(bit, false);
            }
          }
        }
      }
    }
  }
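
  // Illustrative only (not part of the original file): for K == 13 with an 8-deep tile
  // advancing along H, the final mainloop step covers only 13 % 8 == 5 valid rows of K.
  // Calling residue(5) clears the predicate of any access whose K coordinate
  // (block_h + h * Delta::kH + d * Delta::kD) is >= 5, so those LDGs are skipped.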

  /// Is the access at (d, h, w, c) valid?
  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
    int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
    return predicates[bit];
  }

  /// Adds a vector offset to the iterator.
  CUTLASS_HOST_DEVICE GemmGlobalIteratorAb& operator+=(Coord<3> const& offset) {
    long long _offset = offset.template dot<long long>(
        make_Coord(params.stride_d, params.stride_h, params.stride_w));

    params.pointer += _offset;
    return *this;
  }

  /// Adds an offset, in scalars, to the pointer.
  CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }

  /// Returns the stride in the advance dimension.
  CUTLASS_HOST_DEVICE Index stride_advance(void) {
    Index stride = params.stride_h;
    if (kAdvance == IteratorAdvance::kW) {
      stride = params.stride_w;
    }
    return stride;
  }

  /// Loads a tile and advances the iterator to the next tile.
  template <typename Fragment>
  CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
    typename Base::FragmentIterator frag_iterator(fragment);
    for (int d = 0; d < Base::Iterations::kD; ++d) {
      for (int h = 0; h < Base::Iterations::kH; ++h) {
        for (int w = 0; w < Base::Iterations::kW; ++w) {
          for (int c = 0; c < Base::Iterations::kC; ++c) {
            if (valid(d, h, w, c)) {
              load_element(
                  reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
                  d,
                  h,
                  w,
                  c);
            }
          }
          if (w < Base::Iterations::kW - 1) {
            inc_w();
          }
        }
        if (h < Base::Iterations::kH - 1) {
          inc_h();
        }
      }
      if (d < Base::Iterations::kD - 1) {
        inc_d();
      }
    }
    inc_advance();
  }
};
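
// Illustrative usage sketch (not part of the original file; names, shapes, and the K-loop
// step are hypothetical). A mainloop streams the A operand tile-by-tile along K by pairing
// load_post_increment() with residue() for the final partial tile:
//
//   typedef GemmGlobalTileTraits<GemmOperand::kA, MatrixLayout::kColumnMajor,
//                                float const, Shape<1, 8, 128>, Shape<1, 8, 32>, 4> TraitsA;
//   typedef GemmGlobalIteratorAb<TraitsA> IteratorA;
//
//   __device__ void consume_a(typename IteratorA::Params const& params,
//                             Coord<3> bounds, Coord<3> block_offset, int k) {
//     IteratorA iter_a(params, bounds, block_offset);
//     iter_a.initialize_predicates(bounds, block_offset);
//     typename IteratorA::Fragment fragment;
//     for (; k > 0; k -= 8) {
//       if (k < 8) {
//         iter_a.residue(k);  // mask off out-of-range K rows in the last tile
//       }
//       iter_a.load_post_increment(fragment);  // guarded LDGs, then advance to the next tile
//       // ... store `fragment` to shared memory and run the math mainloop ...
//     }
//   }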

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TileTraits_, typename Index_ = int>
struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
                                                      typename TileTraits_::Scalar,
                                                      IteratorAdvance::kH,
                                                      MemorySpace::kGlobal,
                                                      Index_> {
  /// This class.
  typedef GemmGlobalIteratorCd<TileTraits_, Index_> This_;
  /// The base class.
  typedef TileIteratorBase<TileTraits_,
                           typename TileTraits_::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kGlobal,
                           Index_>
      Base;

  /// The layout.
  static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;

  /// The scalar.
  typedef typename TileTraits_::Scalar Scalar;
  /// The pointer.
  typedef typename TileTraits_::Pointer Pointer;
  /// The threads.
  typedef typename TileTraits_::Threads Threads;
  /// The index.
  typedef Index_ Index;
  /// The thread offset functor.
  typedef typename TileTraits_::ThreadOffset ThreadOffset;

  /// The params.
  struct Params {
    /// The pointer.
    Pointer pointer;
    /// The stride in the D dimension.
    long long stride_d;
    /// The stride in the H dimension to setup the thread in the block.
    Index stride_h;
    /// The strides to increment the pointer.
    Index inc_advance, inc_h;
    /// The strides to increment the predicate offset.
    Index predicate_inc_advance, predicate_inc_h;
    /// The column offset used to compute the predicates for the columns.
    Index predicate_offset;

    /// Setup the params.
    CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
                                       long long batch_stride,
                                       Index ldm,
                                       Index bound,
                                       Index epilogue_stride_w,
                                       Index epilogue_delta_w) {
      // The pointer.
      this->pointer = pointer;
      // Stride per batch.
      stride_d = batch_stride;
      // Each column of the matrix.
      stride_h = TileTraits_::ThreadsDelta::kH * ldm;
      // Each thread outputs 1 column per iteration. The stride between columns is given by the
      // number of scalars that are loaded per LDS for B.
      inc_h = ldm * TileTraits_::kStrideH;
      inc_advance =
          (ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;

      predicate_offset = bound;
      predicate_inc_h = TileTraits_::kStrideH;
      predicate_inc_advance =
          -((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);

      return 0;
    }
  };
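
  // Illustrative only (not part of the original file): with ldm == 1024, kStrideH == 2,
  // Iterations::kH == 4 and zero epilogue strides, the formulas above give
  //
  //   inc_h       = 1024 * 2              // advance two columns per H iteration
  //   inc_advance = 1024 - 1024 * 2 * 3   // rewind over the tile's columns, plus one column
  //
  // while predicate_offset starts at `bound` and counts down via predicate_inc_h as the
  // iterator walks the columns, disabling accesses once it reaches zero.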

  /// Parameters.
  Params params;
  /// Offset of an individual lane from the start of the tile.
  Coord<4> thread_offset;
  /// The predicates for the row.
  cutlass::PredicateVector<Base::Iterations::kW> predicates;

  /// Ctor.
  CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
                                           const Coord<3>& bounds,
                                           const Coord<3>& block_offset,
                                           ThreadOffset thread_offset_func = ThreadOffset())
      : params(_params) {
    thread_offset = thread_offset_func();
    // Prepare the vector of predicates.
    for (int i = 0; i < Base::Iterations::kW; ++i) {
      predicates.set(i, thread_offset[2] + i * Base::Delta::kW < bounds[2]);
    }
  }

  /// Ctor.
  CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
                                           const Coord<3>& bounds,
                                           const Coord<3>& block,
                                           int offset = 0,
                                           int pred_offset = 0,
                                           ThreadOffset thread_offset_func = ThreadOffset())
      : params(_params) {
    thread_offset = thread_offset_func();
    // Each warp works on a different column of the tile.
    int const h = thread_offset[1] + block[1];
    // Each lane writes a different element.
    int const w = thread_offset[2] + block[2];
    // Setup the pointer.
    params.pointer += ((h * params.stride_h + w) + offset);

    // Prepare the vector of predicates.
    for (int i = 0; i < Base::Iterations::kW; ++i) {
      predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
    }
    params.predicate_offset -= (h + pred_offset);
  }

  /// Increment the pointer in the C dimension.
  CUTLASS_HOST_DEVICE void inc_c() {}
  /// Increment the pointer in the W dimension.
  CUTLASS_HOST_DEVICE void inc_w() {}
  /// Increment the pointer in the H dimension.
  CUTLASS_HOST_DEVICE void inc_h() {
    params.pointer += params.inc_h;
    params.predicate_offset -= params.predicate_inc_h;
  }
  /// Increment the pointer in the D dimension.
  CUTLASS_HOST_DEVICE void inc_d() {}
  /// Increment the pointer to move to the next iteration.
  CUTLASS_HOST_DEVICE void inc_advance() {
    params.pointer += params.inc_advance;
    params.predicate_offset -= params.predicate_inc_advance;
  }

  /// Adds a vector offset to the iterator.
  CUTLASS_HOST_DEVICE GemmGlobalIteratorCd& operator+=(Coord<3> const& offset) {
    long long _offset = offset.template dot<long long>(
        make_Coord(params.stride_d, params.stride_h, 1));
    params.pointer += _offset;
    return *this;
  }

  /// Loads a single fragment element from memory.
  CUTLASS_HOST_DEVICE void load_element(
      typename Base::AccessType& value, int d, int h, int w, int c) const {
    int const offset =
        ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
    Load<Scalar,
         Base::kAccessSize,
         Base::kMemorySpace,
         Base::kFragmentElementType,
         typename Base::FragmentElement,
         Base::Tile::kW,
         Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
  }

  /// Stores a single fragment element into memory.
  CUTLASS_HOST_DEVICE void store_element(
      typename Base::AccessType const& value, int d, int h, int w, int c) {
    int const offset =
        ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
    Store<Scalar,
          Base::kAccessSize,
          Base::kMemorySpace,
          Base::kFragmentElementType,
          typename Base::FragmentElement,
          Base::Tile::kW,
          Base::kAccessSize * sizeof(Scalar)>::store(value, params.pointer, offset);
  }

  /// Tests the validity of the access at (d, h, w, c).
  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
    return predicates.at(w) && params.predicate_offset > 0;
  }

  /// Adds an offset, in scalars, to the pointer.
  CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }

  /// Loads a fragment and increments the iterator.
  template <typename Fragment>
  CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
    typename Base::FragmentIterator frag_iterator(fragment);
    for (int d = 0; d < Base::Iterations::kD; ++d) {
      for (int h = 0; h < Base::Iterations::kH; ++h) {
        for (int w = 0; w < Base::Iterations::kW; ++w) {
          for (int c = 0; c < Base::Iterations::kC; ++c) {
            if (valid(d, h, w, c)) {
              load_element(
                  reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
                  d,
                  h,
                  w,
                  c);
            }
          }
          if (w < Base::Iterations::kW - 1) {
            inc_w();
          }
        }
        if (h < Base::Iterations::kH - 1) {
          inc_h();
        }
      }
      if (d < Base::Iterations::kD - 1) {
        inc_d();
      }
    }
    inc_advance();
  }

  /// Stores a fragment and increments the iterator.
  template <typename Fragment>
  CUTLASS_HOST_DEVICE void store_post_increment(Fragment& fragment) {
    typename Base::FragmentIterator frag_iterator(fragment);
    for (int d = 0; d < Base::Iterations::kD; ++d) {
      for (int h = 0; h < Base::Iterations::kH; ++h) {
        for (int w = 0; w < Base::Iterations::kW; ++w) {
          for (int c = 0; c < Base::Iterations::kC; ++c) {
            if (valid(d, h, w, c)) {
              store_element(
                  reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
                  d,
                  h,
                  w,
                  c);
            }
          }
          if (w < Base::Iterations::kW - 1) {
            inc_w();
          }
        }
        if (h < Base::Iterations::kH - 1) {
          inc_h();
        }
      }
      if (d < Base::Iterations::kD - 1) {
        inc_d();
      }
    }
    inc_advance();
  }
};

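// Illustrative usage sketch (not part of the original file; the traits configuration and
// epilogue math are hypothetical). An epilogue typically pairs two of these iterators over
// the same coordinates, one reading C and one writing D:
//
//   typedef GemmGlobalTileCdTraits<float, Shape<1, 8, 128>, Shape<1, 8, 32>, 2, 1> TraitsC;
//   typedef GemmGlobalIteratorCd<TraitsC> IteratorC;
//
//   __device__ void epilogue_step(typename IteratorC::Params const& params_c,
//                                 typename IteratorC::Params const& params_d,
//                                 Coord<3> bounds, Coord<3> block) {
//     IteratorC iter_c(params_c, bounds, block);
//     IteratorC iter_d(params_d, bounds, block);
//     typename IteratorC::Base::Fragment frag;
//     iter_c.load_post_increment(frag);   // guarded LDGs of the C tile
//     // ... frag = alpha * accumulators + beta * frag ...
//     iter_d.store_post_increment(frag);  // guarded STGs of the D tile
//   }
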
////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass