73 template <
typename Tile_,
76 typename ThreadOffset_,
111 template <
typename Delta_>
125 return (iteration[0] * Delta::kD + offset[0] <
bounds[0]) &&
126 (iteration[1] * Delta::kH + offset[1] <
bounds[1]) &&
127 (iteration[2] * Delta::kW + offset[2] <
bounds[2]);
133 template <
typename T>
136 template <
typename Traits_,
140 typename Index_ = int,
141 typename FragmentElement_ = Scalar_,
170 typedef typename Traits::Tile
Tile;
173 typedef typename Traits::Delta
Delta;
245 long long _inc_advance)
268 long long _inc_advance) {
284 return initialize(stride[0], stride[1], stride[2]);
297 stride_w * Delta::kW * (Iterations::kW - 1);
314 stride_h * Delta::kH * (Iterations::kH - 1) +
315 stride_w * Delta::kW * (Iterations::kW - 1);
341 template <
typename PredicateIterator,
typename PredicateFunctor>
343 PredicateFunctor
const &predicate_func,
346 for (
int d = 0; d < Iterations::kD; ++d) {
348 for (
int h = 0; h < Iterations::kH; ++h) {
350 for (
int w = 0; w < Iterations::kW; ++w) {
351 bool enable = predicate_func(
make_Coord(d, h, w), offset);
352 predicate_it.set(enable);
391 template <
typename Traits_,
395 typename Index_ = int,
396 typename FragmentElement_ = Scalar_,
398 typename Skew_ = Shape<0, 0, 0, 0> >
405 FragmentElementType_,
414 FragmentElementType_,
525 _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
575 Index _inc_advance) {
578 _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
607 typename PredicateIterator>
622 typename PredicateIterator,
624 typename PredicateFunctor>
626 PredicateFunctor
const &functor,
701 int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
702 if (
stage == Tile::kD - 1) {
714 long long _offset = offset.template dot<long long>(
734 template <
typename Fragment,
typename PredicateIterator>
738 for (
int d = 0; d < Iterations::kD; ++d) {
739 for (
int h = 0; h < Iterations::kH; ++h) {
740 for (
int w = 0; w < Iterations::kW; ++w, ++pred_it) {
741 for (
int c = 0; c < Iterations::kC; ++c) {
744 reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
747 if (w < Iterations::kW - 1) {
751 if (h < Iterations::kH - 1) {
755 if (d < Iterations::kD - 1) {
763 template <
typename Fragment>
770 template <
typename Fragment,
typename PredicateIterator>
777 template <
typename Fragment>
780 load(fragment, pred_it);
784 template <
typename Fragment>
787 for (
int h = 0; h < Iterations::kH; ++h) {
788 for (
int w = 0; w < Iterations::kW; ++w) {
789 for (
int c = 0; c < Iterations::kC; ++c) {
790 load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
828 template <
typename Traits_,
832 typename Index_ = int,
833 typename FragmentElement_ = Scalar_,
835 typename Skew_ = Shape<0, 0, 0, 0> >
842 FragmentElementType_,
851 FragmentElementType_,
962 Index _inc_advance) {
963 initialize(ptr, _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
997 Index _inc_advance) {
1000 _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
1029 typename PredicateIterator>
1044 typename PredicateIterator,
1046 typename PredicateFunctor>
1048 PredicateFunctor
const &functor,
1104 int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
1105 if (
stage == Tile::kD - 1) {
1140 template <
typename Fragment,
typename PredicateIterator>
1144 for (
int d = 0; d < Iterations::kD; ++d) {
1145 for (
int h = 0; h < Iterations::kH; ++h) {
1146 for (
int w = 0; w < Iterations::kW; ++w, ++pred_it) {
1147 for (
int c = 0; c < Iterations::kC; ++c) {
1150 reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
1153 if (w < Iterations::kW - 1) {
1157 if (h < Iterations::kH - 1) {
1161 if (d < Iterations::kD - 1) {
1169 template <
typename Fragment>
1176 template <
typename Fragment,
typename PredicateIterator>
1183 template <
typename Fragment>
1186 store(fragment, pred_it);
1204 template <
typename Fragment,
typename PredicateIterator>
1208 for (
int d = 0; d < Iterations::kD; ++d) {
1209 for (
int h = 0; h < Iterations::kH; ++h) {
1210 for (
int w = 0; w < Iterations::kW; ++w, ++pred_it) {
1211 for (
int c = 0; c < Iterations::kC; ++c) {
1214 reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
1217 if (w < Iterations::kW - 1) {
1221 if (h < Iterations::kH - 1) {
1225 if (d < Iterations::kD - 1) {
1233 template <
typename Fragment>
1240 template <
typename Fragment,
typename PredicateIterator>
1247 template <
typename Fragment>
1250 load(fragment, pred_it);
1254 template <
typename Fragment>
1257 for (
int h = 0; h < Iterations::kH; ++h) {
1258 for (
int w = 0; w < Iterations::kW; ++w) {
1259 for (
int c = 0; c < Iterations::kC; ++c) {
1260 load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:990
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:683
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: tile_iterator.h:1190
Vectorize< FragmentElement, kAccessSize >::Type AccessType
The elements loaded/stored by one instruction.
Definition: tile_iterator.h:188
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:883
Delta_ Delta
Definition: tile_iterator.h:113
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:910
CUTLASS_HOST_DEVICE Params()
Initialize params to access storage object.
Definition: tile_iterator.h:501
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:644
CUTLASS_HOST_DEVICE int initialize(SharedStorage const &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:544
Tile_ Tile
Shape of the tile.
Definition: tile_iterator.h:80
Index_ Index
Index type.
Definition: tile_iterator.h:164
Defines a structure containing strides, bounds, and a pointer to tensor data.
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:449
CUTLASS_HOST_DEVICE int initialize(Coord< 4 > const &stride)
Initializes the parameters object from a vector of strides.
Definition: tile_iterator.h:283
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: tile_iterator.h:334
Skew_ Skew
Skew quantity.
Definition: tile_iterator.h:167
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:461
Enum to specify which memory space data resides in.
Definition: load_store.h:38
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:1234
Base::Index Index
Index type.
Definition: tile_iterator.h:877
Base::Storage SharedStorage
Storage object that may be loaded from.
Definition: tile_iterator.h:476
int stage
The stage.
Definition: tile_iterator.h:1020
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:443
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:568
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:199
Scalar * Pointer
Pointer to underlying type.
Definition: tile_iterator.h:919
Traits::ThreadOffset ThreadOffset
Thread offset.
Definition: tile_iterator.h:182
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE int initialize(long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, long long _inc_advance)
Initializes params.
Definition: tile_iterator.h:262
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:318
Shape< 0, 0, 0, 0 > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: tile_iterator.h:102
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:470
A template defining Tile Traits Concept.
Definition: tile_iterator.h:78
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:425
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > Base
Base class.
Definition: tile_iterator.h:416
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: tile_iterator.h:608
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:196
Traits::Iterations Iterations
Iterations.
Definition: tile_iterator.h:179
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d)
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:785
Base::Delta Delta
Delta.
Definition: tile_iterator.h:446
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:735
Definition: load_store.h:48
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:584
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:434
Traits::ImmediateOffsetStrides ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: tile_iterator.h:176
long long inc_d
Definition: tile_iterator.h:223
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:452
CUTLASS_HOST_DEVICE int initialize()
Initializes params to default values.
Definition: tile_iterator.h:1006
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:464
Definition: tile_iterator.h:65
long long inc_advance
Definition: tile_iterator.h:227
Base::Storage SharedStorage
Storage object which may be stored to.
Definition: tile_iterator.h:913
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:871
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: tile_iterator.h:1030
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:1090
ThreadOffset_ ThreadOffset
Functor that returns the logical coordinate of each entity's initial offset in the tile...
Definition: tile_iterator.h:99
Iterator that always returns true.
Definition: predicate_vector.h:309
CUTLASS_HOST_DEVICE Params(Coord< 4 > const &stride)
Constructs params with a stride vector.
Definition: tile_iterator.h:256
Scalar * pointer
Pointer to memory.
Definition: tile_iterator.h:927
CUTLASS_HOST_DEVICE Params(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w)
Definition: tile_iterator.h:949
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:895
Kind
Definition: load_store.h:39
PredicateVector< ShapeCount< Iterations >::kCount > PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:206
CUTLASS_HOST_DEVICE Index stride_advance(void)
Definition: tile_iterator.h:725
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:901
TensorRef< Scalar, 4 > TensorRef
Tensor reference for the store iterator.
Definition: tile_iterator.h:922
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:419
TensorRef< Scalar const, 4 > TensorRef
Tensor reference for the load iterator.
Definition: tile_iterator.h:488
Definition: load_store.h:178
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1248
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: tile_iterator.h:625
CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: tile_iterator.h:1127
FragmentIterator::FragmentShape FragmentShape
The shape of the fragment.
Definition: tile_iterator.h:203
Scalar const * Pointer
The pointer type.
Definition: tile_iterator.h:485
static IteratorAdvance::Kind const kAdvance
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:155
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:62
Index stride_h
Definition: tile_iterator.h:220
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
Parameters.
Definition: tile_iterator.h:925
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
Traits_ Traits
concept TileTraits
Definition: tile_iterator.h:146
Params params
Parameters structure.
Definition: tile_iterator.h:1014
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:907
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:399
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:916
CUTLASS_HOST_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:699
CUTLASS_HOST_DEVICE int initialize()
Gotta have this.
Definition: tile_iterator.h:321
Kind
Definition: load_store.h:48
CUTLASS_HOST_DEVICE RegularTilePredicateFunctor(Coord< 3 > _bounds)
Constructs a predicate functor given the bounds of a tensor.
Definition: tile_iterator.h:120
CUTLASS_HOST_DEVICE TileLoadIterator()
Default constructor.
Definition: tile_iterator.h:640
Definition: load_store.h:40
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:479
Params params
Parameters structure.
Definition: tile_iterator.h:592
FragmentConstIterator< Fragment, Iterations, AccessType > FragmentConstIterator
The fragment const iterator.
Definition: tile_iterator.h:201
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:552
CUTLASS_HOST_DEVICE TileLoadIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: tile_iterator.h:713
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: tile_iterator.h:686
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:898
Definition: tile_iterator.h:482
Definition: tile_iterator.h:134
Iterations_ Iterations
Number of accesses performed.
Definition: tile_iterator.h:86
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:183
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:771
static int const kAccessSize
Access size.
Definition: tile_iterator.h:105
Fragment< Scalar, ShapeCount< Tile >::kCount, kFragmentSize > Storage
The storage.
Definition: tile_iterator.h:194
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:880
Delta_ Delta
Number of steps between accesses along each dimension.
Definition: tile_iterator.h:83
Defines abstractions for efficiently loading and storing vectors to memory.
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:422
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:455
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:868
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:1079
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:1099
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Params(Scalar *ptr)
Definition: tile_iterator.h:939
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:1066
Index inc_h
Definition: tile_iterator.h:224
CUTLASS_HOST_DEVICE Params()
Constructs params.
Definition: tile_iterator.h:235
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > Base
Base class.
Definition: tile_iterator.h:853
Definition: load_store.h:60
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:856
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
CUTLASS_HOST_DEVICE Params(Scalar const *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:505
long long stride_d
Definition: tile_iterator.h:219
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:904
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:64
CUTLASS_HOST_DEVICE Params(TensorRef const &ref)
Constructs with a CompactTensorRef<>
Definition: tile_iterator.h:509
CUTLASS_HOST_DEVICE Params(Scalar const *ptr, long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initialize params to access storage object.
Definition: tile_iterator.h:515
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
CUTLASS_HOST_DEVICE TileStoreIterator()
Default constructor.
Definition: tile_iterator.h:1062
Definition: load_store.h:48
CUTLASS_HOST_DEVICE int initialize(TensorRef const &ref)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:537
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:859
Defines a 1D vector of elements held in the registers of each thread.
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
CUTLASS_HOST_DEVICE Params(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w)
Initialize params to access storage object.
Definition: tile_iterator.h:530
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:493
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:1141
Definition: tile_iterator.h:65
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:428
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:152
CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:1184
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:1017
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:889
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:874
Functor computing a predicate given the logical position of an access.
Definition: tile_iterator.h:112
Traits::Tile Tile
Tile shape.
Definition: tile_iterator.h:170
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset)
Adds a raw offset to the pointer.
Definition: tile_iterator.h:1124
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:185
Parameters.
Definition: tile_iterator.h:491
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:473
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d)
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1255
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:458
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:440
Base::Delta Delta
Delta.
Definition: tile_iterator.h:886
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:560
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:1093
CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:1177
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:680
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:778
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:892
Definition: tile_iterator.h:65
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:431
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:677
static CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &predicate_func, Coord< 3 > const &offset)
Initializes a predicate vector.
Definition: tile_iterator.h:342
Scalar_ Scalar
Scalar element.
Definition: tile_iterator.h:149
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:595
Coord< 3 > bounds
Dimensions of the bounding volume.
Definition: tile_iterator.h:116
Traits::Delta Delta
Distance along each dimension.
Definition: tile_iterator.h:173
static int const kFragmentSize
The size of storage needed per fragment.
Definition: tile_iterator.h:191
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:1170
CUTLASS_HOST_DEVICE Params()
Definition: tile_iterator.h:935
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:975
Base::FragmentElement FragmentElement
Fragment element.
Definition: tile_iterator.h:862
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1241
CUTLASS_HOST_DEVICE Params(TensorRef const &ref)
Constructs with a CompactTensorRef<>
Definition: tile_iterator.h:943
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:674
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:764
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Parameters to the iterator.
Definition: tile_iterator.h:213
CUTLASS_HOST_DEVICE TileStoreIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: tile_iterator.h:1116
CUTLASS_HOST_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:1102
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:467
CUTLASS_HOST_DEVICE bool operator()(Coord< 3 > iteration, Coord< 3 > offset) const
Computes the predicate given the logical position of an access.
Definition: tile_iterator.h:124
CUTLASS_HOST_DEVICE Params(Scalar *ptr, long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Definition: tile_iterator.h:955
Base::Index Index
Index type.
Definition: tile_iterator.h:437
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:1096
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:865
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:982
int stage
Stage argument enables wrapping after some number of tiles have been loaded.
Definition: tile_iterator.h:598
CUTLASS_HOST_DEVICE int initialize(long long _stride_d, Index _stride_h, Index _stride_w)
Initializes the parameters object from individual D, H, and W strides.
Definition: tile_iterator.h:289
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &, Scalar const *ptr, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:659
CUTLASS_HOST_DEVICE int initialize(SharedStorage &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:968
Index inc_w
Definition: tile_iterator.h:225
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: tile_iterator.h:1047
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:836
CUTLASS_HOST_DEVICE Params(long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, long long _inc_advance)
Constructs params.
Definition: tile_iterator.h:239
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset)
Adds a raw offset to the pointer.
Definition: tile_iterator.h:723
Index stride_w
Definition: tile_iterator.h:221
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:1205