76 template <
typename Tile_,
typename Delta_,
typename Iterations_,
typename ThreadOffset_>
94 template <
typename Traits_,
98 typename Index_ = int,
99 typename FragmentElement_ = Scalar_,
128 typedef typename Traits::Tile
Tile;
131 typedef typename Traits::Delta
Delta;
189 Index _inc_advance) {
239 CUTLASS_DEVICE
bool valid(
int d,
int h,
int w,
int c)
const {
return true; }
246 template <
typename PredicateIterator>
250 for (
int d = 0; d < Iterations::kD; ++d) {
251 bool enable_d = (d * Delta::kD + offset[0] < bounds[0]);
252 for (
int h = 0; h < Iterations::kH; ++h) {
253 bool enable_h = (h * Delta::kH + offset[1] < bounds[1]);
254 for (
int w = 0; w < Iterations::kW; ++w) {
255 bool enable_w = (w * Tile::kC * Delta::kW + offset[2] < bounds[2]);
256 predicate_it.set(d, h, w, 0, enable_d && enable_h && enable_w);
294 template <
typename Traits_,
298 typename Index_ = int,
299 typename FragmentElement_ = Scalar_,
301 typename Skew_ = Shape<0, 0, 0, 0> >
416 Index _inc_advance) {
419 _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
446 template <
typename PredicateIterator>
474 Index block_offset_h = 0;
475 Index block_offset_w = 0;
477 block_offset_h = block_offset[1];
478 block_offset_w = block_offset[2];
480 block_offset_h = block_offset[2];
481 block_offset_w = block_offset[1];
496 int const offset = thread_offset_func()[2];
519 int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
520 if (
stage == Tile::kD - 1) {
532 template <
typename Fragment,
typename PredicateIterator>
536 for (
int d = 0; d < Iterations::kD; ++d) {
537 for (
int h = 0; h < Iterations::kH; ++h) {
538 for (
int w = 0; w < Iterations::kW; ++w, ++pred_it) {
541 reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)),
data(), 0);
544 if (w < Iterations::kW - 1) {
548 if (h < Iterations::kH - 1) {
552 if (d < Iterations::kD - 1) {
560 template <
typename Fragment>
567 template <
typename Fragment,
typename PredicateIterator>
574 template <
typename Fragment>
577 load(fragment, pred_it);
612 template <
typename Traits_,
616 typename Index_ = int,
617 typename FragmentElement_ = Scalar_,
619 typename Skew_ = Shape<0, 0, 0, 0> >
728 Index _inc_advance) {
731 _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
758 template <
typename PredicateIterator>
798 int const offset = thread_offset_func()[2];
821 int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
822 if (
stage == Tile::kD - 1) {
834 template <
typename Fragment,
typename PredicateIterator>
838 for (
int d = 0; d < Iterations::kD; ++d) {
839 for (
int h = 0; h < Iterations::kH; ++h) {
840 for (
int w = 0; w < Iterations::kW; ++w, ++pred_it) {
843 reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)),
data(), 0);
845 if (w < Iterations::kW - 1) {
849 if (h < Iterations::kH - 1) {
853 if (d < Iterations::kD - 1) {
861 template <
typename Fragment>
868 template <
typename Fragment,
typename PredicateIterator>
875 template <
typename Fragment>
878 store(fragment, pred_it);
static int const kFragmentSize
The size of storage needed per fragment.
Definition: tile_iterator.h:149
static IteratorFragment::Kind const kIteratorFragment
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:334
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:816
FragmentConstIterator< Fragment, Iterations, AccessType > FragmentConstIterator
The fragment const iterator.
Definition: tile_iterator.h:158
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ > Base
Base class.
Definition: tile_iterator.h:637
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:682
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:367
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:346
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:533
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:355
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:649
FragmentIterator::FragmentShape FragmentShape
The shape of the fragment.
Definition: tile_iterator.h:160
Traits::ThreadOffset ThreadOffset
Thread offset.
Definition: tile_iterator.h:140
static IteratorFragment::Kind const kIteratorFragment
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:652
Skew_ Skew
Skew quantity.
Definition: tile_iterator.h:125
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:676
CUTLASS_HOST_DEVICE int initialize(SharedStorage &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:706
Enum to specify which memory space data resides in.
Definition: load_store.h:39
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:343
Kind
Definition: tile_iterator.h:62
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:227
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:661
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Base::Storage SharedStorage
Storage object which may be stored to.
Definition: tile_iterator.h:694
A template defining Tile Traits Concept.
Definition: tile_iterator.h:77
CUTLASS_HOST_DEVICE Scalar const * data() const
Returns the current pointer.
Definition: tile_iterator.h:502
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ > Base
Base class.
Definition: tile_iterator.h:319
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &, SharedStorage &shared_storage, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:491
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:401
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:379
Params params
Parameters structure.
Definition: tile_iterator.h:745
static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
The load function.
Definition: load_store.h:59
CUTLASS_HOST_DEVICE int initialize(SharedStorage const &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:394
Definition: tile_iterator.h:382
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:325
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:361
Definition: tile_iterator.h:62
CUTLASS_HOST_DEVICE void store(Fragment &fragment, PredicateIterator pred_it) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:869
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:331
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:561
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:468
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: tile_iterator.h:239
Iterations_ Iterations
Number of accesses performed.
Definition: tile_iterator.h:85
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:713
Params params
Parameters structure.
Definition: tile_iterator.h:433
Iterator that always returns true.
Definition: predicate_vector.h:308
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:643
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:640
Kind
Definition: load_store.h:40
Index stride_h
Definition: tile_iterator.h:172
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:862
Fragment< Scalar, ShapeCount< Tile >::kCount, kFragmentSize > Storage
The storage.
Definition: tile_iterator.h:152
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:425
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:807
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:352
CUTLASS_HOST_DEVICE int initialize()
Initializes params to default values.
Definition: tile_iterator.h:737
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:370
Index_ Index
Index type.
Definition: tile_iterator.h:122
static CUTLASS_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
The store function.
Definition: load_store.h:136
Index inc_h
Definition: tile_iterator.h:176
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
Base::Storage SharedStorage
Storage object that may be loaded from.
Definition: tile_iterator.h:376
Parameters.
Definition: tile_iterator.h:700
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector.
Definition: tile_iterator.h:759
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:302
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:322
static CUTLASS_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &offset=make_Coord(0, 0, 0))
Initializes a predicate vector.
Definition: tile_iterator.h:247
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:697
Base::FragmentElement FragmentElement
Fragment element.
Definition: tile_iterator.h:328
Traits::Tile Tile
Tile shape.
Definition: tile_iterator.h:128
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:156
int stage
The stage.
Definition: tile_iterator.h:751
CUTLASS_HOST_DEVICE int initialize(Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:183
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:679
Base::FragmentElement FragmentElement
Fragment element.
Definition: tile_iterator.h:646
Definition: load_store.h:41
CUTLASS_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:517
Kind
Definition: tile_iterator.h:67
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:373
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:409
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, Index _stride_d, Index _stride_h, Index _stride_w, Index _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:721
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:110
Base::Index Index
Index type.
Definition: tile_iterator.h:658
Scalar * pointer
Pointer to memory.
Definition: tile_iterator.h:702
Index inc_advance
Definition: tile_iterator.h:179
Definition: tile_iterator.h:67
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:185
Index stride_w
Definition: tile_iterator.h:173
CUTLASS_HOST_DEVICE TileLoadIterator()
Default constructor.
Definition: tile_iterator.h:464
Defines abstractions for efficiently loading and storing vectors to memory.
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:390
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:780
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:748
Traits::Iterations Iterations
Iterations.
Definition: tile_iterator.h:137
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:143
Tile_ Tile
Shape of the tile.
Definition: tile_iterator.h:79
Delta_ Delta
Number of steps between accesses along each dimension.
Definition: tile_iterator.h:82
CUTLASS_HOST_DEVICE int initialize(Index _stride_d, Index _stride_h, Index _stride_w)
Definition: tile_iterator.h:203
Index stride_d
Definition: tile_iterator.h:171
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:514
Base::Delta Delta
Delta.
Definition: tile_iterator.h:349
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:664
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:358
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:61
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:655
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:511
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:48
Traits::ImmediateOffsetStrides ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: tile_iterator.h:134
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:364
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:670
Defines a 1D vector of elements held in the registers of each thread.
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:102
static IteratorFragment::Kind const kIteratorFragment
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:116
Base::Delta Delta
Delta.
Definition: tile_iterator.h:667
Definition: tile_iterator.h:62
CUTLASS_HOST_DEVICE Scalar * data() const
Returns the current pointer.
Definition: tile_iterator.h:804
ThreadOffset_ ThreadOffset
Functor that returns the logical coordinate of each entity's initial offset in the tile...
Definition: tile_iterator.h:88
Vectorize< FragmentElement, kAccessSize >::Type AccessType
The elements loaded/stored by one instruction.
Definition: tile_iterator.h:146
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:505
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:810
CUTLASS_HOST_DEVICE void store(Fragment &fragment) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:876
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:508
Parameters.
Definition: tile_iterator.h:388
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:337
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:673
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:119
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:685
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:568
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &, SharedStorage &shared_storage, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:793
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector.
Definition: tile_iterator.h:447
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:154
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:575
Definition: tile_iterator.h:62
static IteratorAdvance::Kind const kAdvance
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:113
Index inc_w
Definition: tile_iterator.h:177
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:436
Traits::Delta Delta
Distance along each dimension.
Definition: tile_iterator.h:131
int stage
Stage argument enables wrapping after some number of tiles have been loaded.
Definition: tile_iterator.h:439
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:688
CUTLASS_HOST_DEVICE TileStoreIterator()
Default constructor.
Definition: tile_iterator.h:776
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:691
Scalar const * Pointer
The pointer type.
Definition: tile_iterator.h:385
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Parameters to the iterator.
Definition: tile_iterator.h:170
Base::Index Index
Index type.
Definition: tile_iterator.h:340
CUTLASS_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:819
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment, PredicateIterator pred_it)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:835
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:813
PredicateVector< ShapeCount< Iterations >::kCount > PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:163
Definition: tile_iterator.h:67
Scalar_ Scalar
Scalar element.
Definition: tile_iterator.h:107
Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix.
Definition: tile_iterator.h:66
Index inc_d
Definition: tile_iterator.h:175
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:620
Traits_ Traits
concept TileTraits
Definition: tile_iterator.h:104