Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tile_iterator.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <cutlass/fragment.h>
32 #include <cutlass/load_store.h>
34 #include <cutlass/vector.h>
35 
36 namespace cutlass {
37 
39 
58 
62  enum Kind { kD, kH, kW };
63 };
64 
68 };
69 
71 
/// A template defining the Tile Traits Concept: it simply re-exports its four template
/// arguments under fixed member names so iterators can query them.
76 template <typename Tile_, typename Delta_, typename Iterations_, typename ThreadOffset_>
77 struct TileTraits {
     /// Shape of the tile.
79  typedef Tile_ Tile;
80 
     /// Number of steps between accesses along each dimension.
82  typedef Delta_ Delta;
83 
     /// Number of accesses performed.
85  typedef Iterations_ Iterations;
86 
     /// Functor that returns the logical coordinate of each entity's initial offset in the tile.
88  typedef ThreadOffset_ ThreadOffset;
89 };
90 
92 
94 template <typename Traits_,
95  typename Scalar_,
98  typename Index_ = int,
99  typename FragmentElement_ = Scalar_,
101  typename Skew_ = Shape<0, 0, 0, 0> >
104  typedef Traits_ Traits;
105 
107  typedef Scalar_ Scalar;
108 
110  typedef FragmentElement_ FragmentElement;
111 
113  static IteratorAdvance::Kind const kAdvance = Advance_;
114 
116  static IteratorFragment::Kind const kIteratorFragment = IteratorFragment_;
117 
120 
122  typedef Index_ Index;
123 
125  typedef Skew_ Skew;
126 
128  typedef typename Traits::Tile Tile;
129 
131  typedef typename Traits::Delta Delta;
132 
134  typedef typename Traits::ImmediateOffsetStrides ImmediateOffsetStrides;
135 
137  typedef typename Traits::Iterations Iterations;
138 
140  typedef typename Traits::ThreadOffset ThreadOffset;
141 
143  static int const kAccessSize = Tile::kC;
144 
147 
149  static int const kFragmentSize =
161 
164 
165  //
166  // Params struct
167  //
168 
/// Parameters to the iterator: per-dimension strides plus the increments stored for the
/// D, H, W and "advance" steps.
// NOTE(review): this rendered listing omits source lines 171-179; per the generated index
// those lines declare the Index members assigned below (stride_d, stride_h, stride_w,
// inc_d, inc_h, inc_w, inc_advance).
170  struct Params {
174 
178 
180 
     /// Initializes params by storing every stride and increment exactly as supplied.
     /// Always returns 0.
183  int initialize(Index _stride_d,
184  Index _stride_h,
185  Index _stride_w,
186  Index _inc_d,
187  Index _inc_h,
188  Index _inc_w,
189  Index _inc_advance) {
190  stride_d = _stride_d;
191  stride_h = _stride_h;
192  stride_w = _stride_w;
193 
194  inc_d = _inc_d;
195  inc_h = _inc_h;
196  inc_w = _inc_w;
197  inc_advance = _inc_advance;
198 
199  return 0;
200  }
201 
     /// Initializes params from the three strides only; the increments are derived from the
     /// traits' Delta, Iterations and Tile shapes and from kAdvance. Always returns 0.
203  int initialize(Index _stride_d, Index _stride_h, Index _stride_w) {
204  stride_d = _stride_d;
205  stride_h = _stride_h;
206  stride_w = _stride_w;
207 
     // inc_w is one Delta::kW step along W. inc_h is one Delta::kH step along H minus
     // (Iterations::kW - 1) W steps -- presumably to undo the W increments applied while
     // walking a row (the inc_w()/inc_h() bodies are not visible in this listing).
208  inc_w = stride_w * Delta::kW;
209  inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
210 
     // inc_d is selected by the compile-time advance direction.
211  if (kAdvance == IteratorAdvance::kH) {
212  // Advance in the H dimension.
213  inc_d = 0;
214  } else if (kAdvance == IteratorAdvance::kW) {
215  // Advance in the W dimension.
216  inc_d = stride_w * Tile::kW - stride_h * Tile::kH;
217  } else {
218  // Advance in the D dimension.
219  inc_d = stride_d;
220  }
221 
     // This overload stores no additional advance increment.
222  inc_advance = 0;
223 
224  return 0;
225  }
226 
     // NOTE(review): source line 227 is missing from this listing; the generated index places
     // the header of the zero-argument `int initialize()` there, so the statements below are
     // presumably its body. They zero stride_d and stride_h, set stride_w to 1, clear all
     // four increments and return 0.
228  stride_d = 0;
229  stride_h = 0;
230  stride_w = 1;
231 
232  inc_d = inc_h = inc_w = inc_advance = 0;
233 
234  return 0;
235  }
236  };
237 
/// Is the iterator valid? This implementation ignores the (d, h, w, c) coordinates and
/// unconditionally returns true.
239  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
240 
241  //
242  // Static function members
243  //
244 
/// Initializes a predicate vector: for every (d, h, w) access of the iteration space, sets
/// the predicate to true only when the access coordinate is strictly below `bounds` in all
/// three dimensions.
///
/// predicate_it  object whose set(d, h, w, 0, value) is called once per access
/// bounds        exclusive upper limits, indexed [0]=D, [1]=H, [2]=W
/// offset        starting coordinate added to each access coordinate (defaults to the origin)
246  template <typename PredicateIterator>
247  CUTLASS_DEVICE static void initialize_predicates(PredicateIterator predicate_it,
248  Coord<3> const &bounds,
249  Coord<3> const &offset = make_Coord(0, 0, 0)) {
250  for (int d = 0; d < Iterations::kD; ++d) {
251  bool enable_d = (d * Delta::kD + offset[0] < bounds[0]);
252  for (int h = 0; h < Iterations::kH; ++h) {
253  bool enable_h = (h * Delta::kH + offset[1] < bounds[1]);
254  for (int w = 0; w < Iterations::kW; ++w) {
     // Unlike D and H, the W coordinate is additionally scaled by Tile::kC.
255  bool enable_w = (w * Tile::kC * Delta::kW + offset[2] < bounds[2]);
256  predicate_it.set(d, h, w, 0, enable_d && enable_h && enable_w);
257  }
258  }
259  }
260  }
261 };
262 
264 
288 
294 template <typename Traits_,
295  typename Scalar_,
298  typename Index_ = int,
299  typename FragmentElement_ = Scalar_,
301  typename Skew_ = Shape<0, 0, 0, 0> >
302 struct TileLoadIterator : public TileIteratorBase<Traits_,
303  Scalar_,
304  Advance_,
305  MemorySpace,
306  Index_,
307  FragmentElement_,
308  IteratorFragment_,
309  Skew_> {
311  typedef TileIteratorBase<Traits_,
312  Scalar_,
313  Advance_,
314  MemorySpace,
315  Index_,
316  FragmentElement_,
317  IteratorFragment_,
318  Skew_>
320 
322  typedef typename Base::Traits Traits;
323 
325  typedef typename Base::Scalar Scalar;
326 
329 
332 
335 
338 
340  typedef typename Base::Index Index;
341 
343  typedef typename Base::Skew Skew;
344 
346  typedef typename Base::Tile Tile;
347 
349  typedef typename Base::Delta Delta;
350 
352  typedef typename Base::Iterations Iterations;
353 
356 
359 
361  typedef typename Base::AccessType AccessType;
362 
364  typedef typename Base::Fragment Fragment;
365 
368 
371 
374 
376  typedef typename Base::Storage SharedStorage;
377 
379  typedef typename Base::Params BaseParams;
380 
382  enum { kRequiresLoadFence = Tile::kD == 1 };
383 
385  typedef Scalar const *Pointer;
386 
/// Parameters of the load iterator: the base iterator parameters plus a read-only pointer
/// to the memory being loaded.
388  struct Params : public BaseParams {
     /// Pointer to memory.
390  Scalar const *pointer;
391 
     /// Initialize params to access a storage object: points at its element 0. Returns 0.
394  int initialize(SharedStorage const &storage) {
395  pointer = &storage[0];
396  return 0;
397  }
398 
     // NOTE(review): source lines 401-402 are missing from this listing. The generated index
     // places `int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)`
     // ("Initializes params to access a raw pointer") at line 401, so the statements below are
     // presumably the tail of that overload; any stride handling on line 402 is not visible.
403  pointer = ptr;
404  return 0;
405  }
406 
     /// Initializes params: stores the raw pointer and takes explicit strides and increments.
     /// Returns 0.
409  int initialize(Scalar const *ptr,
410  Index _stride_d,
411  Index _stride_h,
412  Index _stride_w,
413  Index _inc_d,
414  Index _inc_h,
415  Index _inc_w,
416  Index _inc_advance) {
417  pointer = ptr;
     // NOTE(review): source line 418 is missing; the argument list below presumably belongs
     // to a call that forwards the strides and increments to the base params -- confirm
     // against the original header.
419  _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
420  return 0;
421  }
422 
     // NOTE(review): the zero-argument initialize() announced by the comment below (index:
     // line 425) is not present in this listing.
423  // Initializes params to default values
426  };
427 
428  //
429  // Data members
430  //
431 
433  Params params;
434 
437 
439  int stage;
440 
441  //
442  // Static member functions
443  //
444 
/// Initializes a predicate vector for this iterator. The offset handed on is `block_offset`
/// plus (0, thread_offset[1], thread_offset[2] * Tile::kC).
// NOTE(review): source line 452 is missing from this listing; the three arguments below are
// presumably passed to the base class's initialize_predicates -- confirm against the header.
446  template <typename PredicateIterator>
447  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
448  Coord<3> const &bounds,
449  Coord<3> const &block_offset = make_Coord(0,
450  0,
451  0)) {
453  predicate_it,
454  bounds,
455  block_offset + make_Coord(0, thread_offset[1], thread_offset[2] * Tile::kC));
456  }
457 
458  //
459  // Methods
460  //
461 
465 
/// Constructs a tile load iterator. Copies `_params`, starts at stage 0, evaluates the
/// thread-offset functor, and then moves params.pointer by the combined block and thread
/// offsets.
468  TileLoadIterator(Params const &_params,
469  Coord<3> const &block_offset = make_Coord(0, 0, 0),
470  ThreadOffset thread_offset_func = ThreadOffset())
471  : params(_params), stage(0) {
472  thread_offset = thread_offset_func();
473 
474  Index block_offset_h = 0;
475  Index block_offset_w = 0;
     // When advancing in H, block_offset[1] is the H offset and block_offset[2] the W offset;
     // for every other advance kind the two components are swapped.
476  if (kAdvance == IteratorAdvance::kH) {
477  block_offset_h = block_offset[1];
478  block_offset_w = block_offset[2];
479  } else {
480  block_offset_h = block_offset[2];
481  block_offset_w = block_offset[1];
482  }
483 
     // Pointer displacement = D term + H term + W term. The W coordinate
     // (block_offset_w + thread_offset[2] * Tile::kC) is divided by Tile::kC before being
     // multiplied by stride_w.
484  params.pointer += block_offset[0] * params.stride_d +
485  (block_offset_h + thread_offset[1]) * params.stride_h +
486  (block_offset_w + thread_offset[2] * Tile::kC) / Tile::kC * params.stride_w;
487  }
488 
/// Constructs a tile load iterator over a shared-storage object. The Params argument is
/// unnamed and unused, and `block_offset` is not read; only params.pointer is assigned here,
/// to the storage element indexed by component [2] of the thread offset. Starts at stage 0.
491  TileLoadIterator(Params const &,
492  SharedStorage &shared_storage,
493  Coord<3> const &block_offset = make_Coord(0, 0, 0),
494  ThreadOffset thread_offset_func = ThreadOffset())
495  : stage(0) {
496  int const offset = thread_offset_func()[2];
497  params.pointer = &shared_storage[offset];
498  }
499 
/// Returns the current pointer (params.pointer) as a pointer to const Scalar.
502  Scalar const *data() const { return params.pointer; }
503 
506 
509 
512 
515 
/// Increment the stage. Does nothing unless Tile::kD > 1. Otherwise the pointer moves
/// forward by one stage (Tile::kH * Tile::kW * Tile::kC scalars); after the last stage
/// (Tile::kD - 1) it is moved back by Tile::kD - 1 stages and the stage counter is reset to 0.
517  CUTLASS_DEVICE void inc_stage() {
518  if (Tile::kD > 1) {
519  int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
520  if (stage == Tile::kD - 1) {
     // Wrap around: rewind the pointer to where stage 0 started.
521  params.pointer -= (Tile::kD - 1) * kStageSize;
522  stage = 0;
523  } else {
524  params.pointer += kStageSize;
525  stage = stage + 1;
526  }
527  }
528  }
529 
530  public:
/// Loads a fragment and advances the iterator to the next tile.
///
/// Walks the D/H/W iteration space; `pred_it` is incremented once per W iteration and the
/// access for (d, h, w) is only performed when `*pred_it` is true. inc_w()/inc_h()/inc_d()
/// are called between consecutive iterations of their loop (not after the last one), and
/// inc_advance() is called once after all loops.
532  template <typename Fragment, typename PredicateIterator>
533  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
534  FragmentIterator frag_iterator(fragment);
535 
536  for (int d = 0; d < Iterations::kD; ++d) {
537  for (int h = 0; h < Iterations::kH; ++h) {
538  for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
539  if (*pred_it) {
     // NOTE(review): source line 540 is missing from this listing; the arguments below
     // (destination fragment element viewed as AccessType, data(), offset 0) presumably
     // belong to a call to the Load<...>::load function from load_store.h -- confirm.
541  reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), data(), 0);
542  }
543 
544  if (w < Iterations::kW - 1) {
545  inc_w();
546  }
547  }
548  if (h < Iterations::kH - 1) {
549  inc_h();
550  }
551  }
552  if (d < Iterations::kD - 1) {
553  inc_d();
554  }
555  }
556  inc_advance();
557  }
558 
/// Loads a fragment and advances the iterator to the next tile, using a default-constructed
/// PredicateVector::TrivialIterator as the predicate.
// NOTE(review): source line 561 (the function header; the generated index gives
// `void load_post_increment(Fragment &fragment)`) is missing from this listing.
560  template <typename Fragment>
562  typename PredicateVector::TrivialIterator pred_it;
563  load_post_increment(fragment, pred_it);
564  }
565 
/// Loads a fragment without advancing the iterator: the post-incrementing load runs on a
/// copy of *this, so this (const) iterator is left untouched.
567  template <typename Fragment, typename PredicateIterator>
568  CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
569  TileLoadIterator _load_it(*this);
570  _load_it.load_post_increment(fragment, pred_it);
571  }
572 
/// Loads a fragment without advancing the iterator, forwarding to the predicated overload
/// with a default-constructed PredicateVector::TrivialIterator.
574  template <typename Fragment>
575  CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
576  typename PredicateVector::TrivialIterator pred_it;
577  load(fragment, pred_it);
578  }
579 };
580 
582 
606 
612 template <typename Traits_,
613  typename Scalar_,
616  typename Index_ = int,
617  typename FragmentElement_ = Scalar_,
619  typename Skew_ = Shape<0, 0, 0, 0> >
620 struct TileStoreIterator : public TileIteratorBase<Traits_,
621  Scalar_,
622  Advance_,
623  MemorySpace,
624  Index_,
625  FragmentElement_,
626  IteratorFragment_,
627  Skew_> {
629  typedef TileIteratorBase<Traits_,
630  Scalar_,
631  Advance_,
632  MemorySpace,
633  Index_,
634  FragmentElement_,
635  IteratorFragment_,
636  Skew_>
638 
640  typedef typename Base::Traits Traits;
641 
643  typedef typename Base::Scalar Scalar;
644 
647 
650 
653 
656 
658  typedef typename Base::Index Index;
659 
661  typedef typename Base::Skew Skew;
662 
664  typedef typename Base::Tile Tile;
665 
667  typedef typename Base::Delta Delta;
668 
670  typedef typename Base::Iterations Iterations;
671 
674 
677 
679  typedef typename Base::AccessType AccessType;
680 
682  typedef typename Base::Fragment Fragment;
683 
686 
689 
692 
694  typedef typename Base::Storage SharedStorage;
695 
697  typedef typename Base::Params BaseParams;
698 
/// Parameters of the store iterator: the base iterator parameters plus a pointer to the
/// memory being written.
// NOTE(review): source line 702 is missing from this listing; the generated index places
// the member `Scalar *pointer` ("Pointer to memory") there.
700  struct Params : public BaseParams {
703 
     /// Initialize params to access a storage object: points at its element 0. Returns 0.
706  int initialize(SharedStorage &storage) {
707  pointer = &storage[0];
708  return 0;
709  }
710 
     // NOTE(review): source lines 713-714 are missing from this listing. The generated index
     // places `int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w)`
     // ("Initializes params to access a raw pointer") at line 713, so the statements below are
     // presumably the tail of that overload; any stride handling on line 714 is not visible.
715  pointer = ptr;
716  return 0;
717  }
718 
     /// Initializes params: stores the raw pointer and takes explicit strides and increments.
     /// Returns 0.
721  int initialize(Scalar *ptr,
722  Index _stride_d,
723  Index _stride_h,
724  Index _stride_w,
725  Index _inc_d,
726  Index _inc_h,
727  Index _inc_w,
728  Index _inc_advance) {
729  pointer = ptr;
     // NOTE(review): source line 730 is missing; the argument list below presumably belongs
     // to a call that forwards the strides and increments to the base params -- confirm
     // against the original header.
731  _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
732  return 0;
733  }
734 
     // NOTE(review): the zero-argument initialize() ("Initializes params to default values",
     // index: line 737) is not present in this listing.
738  };
739 
740  //
741  // Data members
742  //
743 
746 
749 
751  int stage;
752 
753  //
754  // Static member functions
755  //
756 
/// Initializes a predicate vector for this iterator. The offset handed on is `block_offset`
/// plus (0, thread_offset[1], thread_offset[2] * Tile::kC).
// NOTE(review): source line 764 is missing from this listing; the three arguments below are
// presumably passed to the base class's initialize_predicates -- confirm against the header.
758  template <typename PredicateIterator>
759  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
760  Coord<3> const &bounds,
761  Coord<3> const &block_offset = make_Coord(0,
762  0,
763  0)) {
765  predicate_it,
766  bounds,
767  block_offset + make_Coord(0, thread_offset[1], thread_offset[2] * Tile::kC));
768  }
769 
770  //
771  // Methods
772  //
773 
777 
/// Constructs a tile store iterator. Copies `_params`, starts at stage 0, evaluates the
/// thread-offset functor, and then moves params.pointer by the combined block and thread
/// offsets. block_offset[1] is always used as the H offset and block_offset[2] as the W
/// offset here, regardless of kAdvance.
780  TileStoreIterator(Params const &_params,
781  Coord<3> const &block_offset = make_Coord(0, 0, 0),
782  ThreadOffset thread_offset_func = ThreadOffset())
783  : params(_params), stage(0) {
784  thread_offset = thread_offset_func();
785 
     // Pointer displacement = D term + H term + W term. The W coordinate
     // (block_offset[2] + thread_offset[2] * Tile::kC) is divided by Tile::kC before being
     // multiplied by stride_w.
786  params.pointer += block_offset[0] * params.stride_d +
787  (block_offset[1] + thread_offset[1]) * params.stride_h +
788  (block_offset[2] + thread_offset[2] * Tile::kC) / Tile::kC * params.stride_w;
789  }
790 
// NOTE(review): source line 793 (the start of this constructor's header) is missing from the
// listing; the generated index gives it as `TileStoreIterator(Params const &, ...)`,
// "Constructs a tile store iterator".
// The visible body does not read `block_offset`; it sets stage to 0 and points
// params.pointer at the storage element indexed by component [2] of the thread offset.
794  SharedStorage &shared_storage,
795  Coord<3> const &block_offset = make_Coord(0, 0, 0),
796  ThreadOffset thread_offset_func = ThreadOffset())
797  : stage(0) {
798  int const offset = thread_offset_func()[2];
799  params.pointer = &shared_storage[offset];
800  }
801 
/// Returns the current pointer (params.pointer) as a pointer to mutable Scalar.
804  Scalar *data() const { return params.pointer; }
805 
808 
811 
814 
817 
/// Increment the stage. Does nothing unless Tile::kD > 1. Otherwise the pointer moves
/// forward by one stage (Tile::kH * Tile::kW * Tile::kC scalars); after the last stage
/// (Tile::kD - 1) it is moved back by Tile::kD - 1 stages and the stage counter is reset to 0.
819  CUTLASS_DEVICE void inc_stage() {
820  if (Tile::kD > 1) {
821  int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
822  if (stage == Tile::kD - 1) {
     // Wrap around: rewind the pointer to where stage 0 started.
823  params.pointer -= (Tile::kD - 1) * kStageSize;
824  stage = 0;
825  } else {
826  params.pointer += kStageSize;
827  stage = stage + 1;
828  }
829  }
830  }
831 
832  public:
/// Stores a fragment and advances to the next tile.
///
/// Walks the D/H/W iteration space; `pred_it` is incremented once per W iteration and the
/// access for (d, h, w) is only performed when `*pred_it` is true. inc_w()/inc_h()/inc_d()
/// are called between consecutive iterations of their loop (not after the last one), and
/// inc_advance() is called once after all loops.
834  template <typename Fragment, typename PredicateIterator>
835  CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment, PredicateIterator pred_it) {
836  FragmentIterator frag_iterator(fragment);
837 
838  for (int d = 0; d < Iterations::kD; ++d) {
839  for (int h = 0; h < Iterations::kH; ++h) {
840  for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
841  if (*pred_it) {
     // NOTE(review): source line 842 is missing from this listing; the arguments below
     // (source fragment element viewed as AccessType, data(), offset 0) presumably belong
     // to a call to the Store<...>::store function from load_store.h -- confirm.
843  reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), data(), 0);
844  }
845  if (w < Iterations::kW - 1) {
846  inc_w();
847  }
848  }
849  if (h < Iterations::kH - 1) {
850  inc_h();
851  }
852  }
853  if (d < Iterations::kD - 1) {
854  inc_d();
855  }
856  }
857  inc_advance();
858  }
859 
/// Stores a fragment and advances to the next tile, using a default-constructed
/// PredicateVector::TrivialIterator as the predicate.
// NOTE(review): source line 862 (the function header; the generated index gives
// `void store_post_increment(Fragment &fragment)`) is missing from this listing.
861  template <typename Fragment>
863  typename PredicateVector::TrivialIterator pred_it;
864  store_post_increment(fragment, pred_it);
865  }
866 
/// Stores a fragment without advancing the iterator: the post-incrementing store runs on a
/// copy of *this, so this (const) iterator is left untouched.
868  template <typename Fragment, typename PredicateIterator>
869  CUTLASS_HOST_DEVICE void store(Fragment &fragment, PredicateIterator pred_it) const {
870  TileStoreIterator _store_it(*this);
871  _store_it.store_post_increment(fragment, pred_it);
872  }
873 
/// Stores a fragment without advancing the iterator, forwarding to the predicated overload
/// with a default-constructed PredicateVector::TrivialIterator.
875  template <typename Fragment>
876  CUTLASS_HOST_DEVICE void store(Fragment &fragment) const {
877  typename PredicateVector::TrivialIterator pred_it;
878  store(fragment, pred_it);
879  }
880 };
881 }
static int const kFragmentSize
The size of storage needed per fragment.
Definition: tile_iterator.h:149
static IteratorFragment::Kind const kIteratorFragment
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:334
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:816
FragmentConstIterator< Fragment, Iterations, AccessType > FragmentConstIterator
The fragment const iterator.
Definition: tile_iterator.h:158
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ > Base
Base class.
Definition: tile_iterator.h:637
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:682
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:367
Definition: convert.h:33
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:346
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:533
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:355
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:649
FragmentIterator::FragmentShape FragmentShape
The shape of the fragment.
Definition: tile_iterator.h:160
Traits::ThreadOffset ThreadOffset
Thread offset.
Definition: tile_iterator.h:140
static IteratorFragment::Kind const kIteratorFragment
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:652
Skew_ Skew
Skew quantity.
Definition: tile_iterator.h:125
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:676
CUTLASS_HOST_DEVICE int initialize(SharedStorage &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:706
Enum to specify which memory space data resides in.
Definition: load_store.h:39
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:343
Kind
Definition: tile_iterator.h:62
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:227
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:661
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Base::Storage SharedStorage
Storage object which may be stored to.
Definition: tile_iterator.h:694
A template defining Tile Traits Concept.
Definition: tile_iterator.h:77
CUTLASS_HOST_DEVICE Scalar const * data() const
Returns the current pointer.
Definition: tile_iterator.h:502
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ > Base
Base class.
Definition: tile_iterator.h:319
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &, SharedStorage &shared_storage, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:491
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:401
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:379
Params params
Parameters structure.
Definition: tile_iterator.h:745
static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
The load function.
Definition: load_store.h:59
CUTLASS_HOST_DEVICE int initialize(SharedStorage const &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:394
Definition: tile_iterator.h:382
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:325
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:361
Definition: tile_iterator.h:62
CUTLASS_HOST_DEVICE void store(Fragment &fragment, PredicateIterator pred_it) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:869
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:331
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:561
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:468
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: tile_iterator.h:239
Iterations_ Iterations
Number of accesses performed.
Definition: tile_iterator.h:85
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:713
Params params
Parameters structure.
Definition: tile_iterator.h:433
Iterator that always returns true.
Definition: predicate_vector.h:308
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:643
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:640
Kind
Definition: load_store.h:40
Index stride_h
Definition: tile_iterator.h:172
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:862
Fragment< Scalar, ShapeCount< Tile >::kCount, kFragmentSize > Storage
The storage.
Definition: tile_iterator.h:152
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:425
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:807
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:352
CUTLASS_HOST_DEVICE int initialize()
Initializes params to default values.
Definition: tile_iterator.h:737
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:370
Index_ Index
Index type.
Definition: tile_iterator.h:122
static CUTLASS_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
The store function.
Definition: load_store.h:136
Index inc_h
Definition: tile_iterator.h:176
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
Base::Storage SharedStorage
Storage object that may be loaded from.
Definition: tile_iterator.h:376
Parameters.
Definition: tile_iterator.h:700
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector.
Definition: tile_iterator.h:759
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:302
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:322
static CUTLASS_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &offset=make_Coord(0, 0, 0))
Initializes a predicate vector.
Definition: tile_iterator.h:247
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:697
Base::FragmentElement FragmentElement
Fragment element.
Definition: tile_iterator.h:328
Traits::Tile Tile
Tile shape.
Definition: tile_iterator.h:128
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:156
int stage
The stage.
Definition: tile_iterator.h:751
CUTLASS_HOST_DEVICE int initialize(Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:183
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:679
Base::FragmentElement FragmentElement
Fragment element.
Definition: tile_iterator.h:646
Definition: load_store.h:41
CUTLASS_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:517
Kind
Definition: tile_iterator.h:67
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:373
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:409
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:721
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:110
Base::Index Index
Index type.
Definition: tile_iterator.h:658
Scalar * pointer
Pointer to memory.
Definition: tile_iterator.h:702
Index inc_advance
Definition: tile_iterator.h:179
Definition: tile_iterator.h:67
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:185
Index stride_w
Definition: tile_iterator.h:173
CUTLASS_HOST_DEVICE TileLoadIterator()
Default constructor.
Definition: tile_iterator.h:464
Defines abstractions for efficiently loading and storing vectors to memory.
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:390
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:780
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:748
Traits::Iterations Iterations
Iterations.
Definition: tile_iterator.h:137
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:143
Tile_ Tile
Shape of the tile.
Definition: tile_iterator.h:79
Delta_ Delta
Number of steps between accesses along each dimension.
Definition: tile_iterator.h:82
CUTLASS_HOST_DEVICE int initialize(Index _stride_d, Index _stride_h, Index _stride_w)
Definition: tile_iterator.h:203
Index stride_d
Definition: tile_iterator.h:171
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:514
Definition: vector.h:61
Base::Delta Delta
Delta.
Definition: tile_iterator.h:349
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:664
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:358
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:61
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:655
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:511
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:48
Traits::ImmediateOffsetStrides ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: tile_iterator.h:134
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:364
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:670
Defines a 1D vector of elements held in the registers of each thread.
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:102
static IteratorFragment::Kind const kIteratorFragment
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:116
Base::Delta Delta
Delta.
Definition: tile_iterator.h:667
Definition: tile_iterator.h:62
CUTLASS_HOST_DEVICE Scalar * data() const
Returns the current pointer.
Definition: tile_iterator.h:804
ThreadOffset_ ThreadOffset
Functor that returns the logical coordinate of each entity's initial offset in the tile...
Definition: tile_iterator.h:88
Vectorize< FragmentElement, kAccessSize >::Type AccessType
The elements loaded/store by one instruction.
Definition: tile_iterator.h:146
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:505
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:810
CUTLASS_HOST_DEVICE void store(Fragment &fragment) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:876
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:508
Parameters.
Definition: tile_iterator.h:388
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:337
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:673
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:119
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:685
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:568
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &, SharedStorage &shared_storage, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:793
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector.
Definition: tile_iterator.h:447
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:154
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:575
Definition: tile_iterator.h:62
static IteratorAdvance::Kind const kAdvance
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:113
Index inc_w
Definition: tile_iterator.h:177
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:436
Traits::Delta Delta
Distance along each dimension.
Definition: tile_iterator.h:131
int stage
Stage argument enables wrapping after some number of tiles have been loaded.
Definition: tile_iterator.h:439
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:688
CUTLASS_HOST_DEVICE TileStoreIterator()
Default constructor.
Definition: tile_iterator.h:776
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:691
Scalar const * Pointer
The pointer type.
Definition: tile_iterator.h:385
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Parameters to the iterator.
Definition: tile_iterator.h:170
Base::Index Index
Index type.
Definition: tile_iterator.h:340
CUTLASS_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:819
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment, PredicateIterator pred_it)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:835
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:813
PredicateVector< ShapeCount< Iterations >::kCount > PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:163
Definition: tile_iterator.h:67
Scalar_ Scalar
Scalar element.
Definition: tile_iterator.h:107
Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix.
Definition: tile_iterator.h:66
Index inc_d
Definition: tile_iterator.h:175
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:620
Traits_ Traits
concept TileTraits
Definition: tile_iterator.h:104