52 template <
typename Tile_,
typename Threads_,
bool = (Tile_::kW < Threads_::kW)>
53 struct ReshapeThreads {
54 typedef Threads_ Threads;
57 template <
typename Tile_,
typename Threads_>
59 typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1>
Threads;
98 VectorizedTile::kH / Threads::kH,
99 VectorizedTile::kW / Threads::kW,
112 return make_Coord(0, thread_offset_h, thread_offset_w, 0);
119 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kStrideH_,
int kAccessSize_>
121 MatrixLayout::kColumnMajor,
155 return make_Coord(0, thread_offset_h, thread_offset_w, 0);
162 template <
typename TileTraits_,
typename Index_ =
int>
165 typename TileTraits_::Scalar,
166 TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
167 : IteratorAdvance::kW,
168 MemorySpace::kGlobal,
173 typename TileTraits_::Scalar,
182 typedef typename TileTraits_::Tile
Tile;
186 typedef typename TileTraits_::Scalar
Scalar;
188 typedef typename TileTraits_::Threads
Threads;
212 if (Base::Delta::kD > 0) {
221 }
else if (Base::Delta::kD > 0) {
223 (Base::Iterations::kH - 1) *
inc_h -
224 (Base::Iterations::kD - 1) * Base::Delta::kD *
stride_h;
227 (Base::Iterations::kH - 1) *
inc_h;
248 for (
int d = 0; d < Base::Iterations::kD; ++d) {
249 for (
int h = 0; h < Base::Iterations::kH; ++h) {
250 for (
int w = 0; w < Base::Iterations::kW; ++w) {
251 for (
int c = 0; c < Base::Iterations::kC; ++c) {
252 bool flag = w * Base::Delta::kW +
thread_offset[2] + block_offset[2] < bounds[2];
256 (h * Base::Delta::kH + d * Base::Delta::kD) +
thread_offset[1] + block_offset[1] <
259 flag = flag && (h * Base::Delta::kH) +
thread_offset[1] + block_offset[1] < bounds[1];
313 for (
int d = 0; d < Base::Iterations::kD; ++d) {
314 for (
int h = 0; h < Base::Iterations::kH; ++h) {
315 for (
int w = 0; w < Base::Iterations::kW; ++w) {
316 for (
int c = 0; c < Base::Iterations::kC; ++c) {
319 offset += block_h + h * Base::Delta::kH + d * Base::Delta::kD;
321 offset += block_w + w * Base::Delta::kW;
343 long long _offset = offset.template dot<long long>(
361 template <
typename Fragment>
364 for (
int d = 0; d < Base::Iterations::kD; ++d) {
365 for (
int h = 0; h < Base::Iterations::kH; ++h) {
366 for (
int w = 0; w < Base::Iterations::kW; ++w) {
367 for (
int c = 0; c < Base::Iterations::kC; ++c) {
368 if (
valid(d, h, w, c)) {
370 reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
377 if (w < Base::Iterations::kW - 1) {
381 if (h < Base::Iterations::kH - 1) {
385 if (d < Base::Iterations::kD - 1) {
395 template <
typename TileTraits_,
typename Index_ =
int>
397 typename TileTraits_::Scalar,
399 MemorySpace::kGlobal,
405 typename TileTraits_::Scalar,
415 typedef typename TileTraits_::Scalar
Scalar;
417 typedef typename TileTraits_::Pointer
Pointer;
419 typedef typename TileTraits_::Threads
Threads;
442 long long batch_stride,
445 Index epilogue_stride_w,
446 Index epilogue_delta_w) {
452 stride_h = TileTraits_::ThreadsDelta::kH * ldm;
455 inc_h = ldm * TileTraits_::kStrideH;
457 (ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
462 -((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
483 for (
int i = 0; i < Base::Iterations::kW; ++i) {
505 for (
int i = 0; i < Base::Iterations::kW; ++i) {
530 long long _offset = offset.template dot<long long>(
574 template <
typename Fragment>
577 for (
int d = 0; d < Base::Iterations::kD; ++d) {
578 for (
int h = 0; h < Base::Iterations::kH; ++h) {
579 for (
int w = 0; w < Base::Iterations::kW; ++w) {
580 for (
int c = 0; c < Base::Iterations::kC; ++c) {
581 if (
valid(d, h, w, c)) {
583 reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
590 if (w < Base::Iterations::kW - 1) {
594 if (h < Base::Iterations::kH - 1) {
598 if (d < Base::Iterations::kD - 1) {
605 template <
typename Fragment>
608 for (
int d = 0; d < Base::Iterations::kD; ++d) {
609 for (
int h = 0; h < Base::Iterations::kH; ++h) {
610 for (
int w = 0; w < Base::Iterations::kW; ++w) {
611 for (
int c = 0; c < Base::Iterations::kC; ++c) {
612 if (
valid(d, h, w, c)) {
614 reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
621 if (w < Base::Iterations::kW - 1) {
625 if (h < Base::Iterations::kH - 1) {
629 if (d < Base::Iterations::kD - 1) {
Definition: gemm_global_tile.h:120
Shape< 0, Threads::kH, Threads::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:92
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: gemm_global_tile.h:529
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Definition: gemm_global_tile.h:362
Index inc_advance
The strides to increment the pointer.
Definition: gemm_global_tile.h:434
cutlass::PredicateVector< ShapeCount< typename Base::Iterations >::kCount > PredicateVector
Definition: gemm_global_tile.h:196
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:180
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const &_params, const Coord< 3 > &bounds, const Coord< 3 > &threadblock_offset, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:270
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:461
Base::Params BaseParams
Iterator parameters type.
Definition: gemm_global_tile.h:199
CUTLASS_HOST_DEVICE void inc_c()
Increment the pointer in the C dimension.
Definition: gemm_global_tile.h:512
Index_ Index
The index.
Definition: gemm_global_tile.h:421
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
ReshapeTile< Tile_, kAccessSize_ >::Tile VectorizedTile
The vectorized tile shape.
Definition: gemm_global_tile.h:86
GemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:402
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:412
Definition: gemm_global_tile.h:70
Scalar_ * Pointer
The pointer.
Definition: gemm_global_tile.h:78
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:199
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:357
Definition: load_store.h:42
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:318
Shape< 1, 1, VectorizedTile::kC > ThreadsDelta
The relative offset between two elements in the H/W dimension in adjacent threads.
Definition: gemm_global_tile.h:90
GemmMultiplicandTraits< Tile, kOperand, kLayout > MultiplicandTraits
Definition: gemm_global_tile.h:103
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:425
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_global_tile.h:82
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:292
TileIteratorBase< TileTraits_, typename TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:409
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:196
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:138
Index predicate_inc_h
Definition: gemm_global_tile.h:436
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:584
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:434
long long inc_d
Definition: tile_iterator.h:223
Tile_ Tile
The tile shape.
Definition: gemm_global_tile.h:84
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:464
Definition: tile_iterator.h:65
long long inc_advance
Definition: tile_iterator.h:227
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads and increments iterator.
Definition: gemm_global_tile.h:575
TileLoadIterator< TileTraits_, typename TileTraits_::Scalar, TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:178
Definition: gemm_global_tile.h:201
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const &_params, const Coord< 3 > &bounds, const Coord< 3 > &block, int offset=0, int pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:489
Definition: matrix_traits.h:357
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:181
Definition: gemm_global_tile.h:163
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: gemm_global_tile.h:133
Kind
Definition: load_store.h:39
CUTLASS_HOST_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:287
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: gemm_global_tile.h:194
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:188
static int const kStrideH
The stride in the H dimension.
Definition: gemm_global_tile.h:136
static int const kH
The height of the cube.
Definition: shape.h:68
Definition: load_store.h:178
Shape< Threads_::kD, Threads_::kH *Threads_::kW/Tile_::kW, Tile_::kW, 1 > Threads
Definition: gemm_global_tile.h:59
Index predicate_inc_advance
The strides to increment the predicate offset.
Definition: gemm_global_tile.h:436
static GemmOperand::Kind const kOperand
Identity of the operand.
Definition: gemm_global_tile.h:72
Index stride_h
Definition: tile_iterator.h:220
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Tests whether the element at (d, h, w, c) is valid.
Definition: gemm_global_tile.h:566
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
CUTLASS_HOST_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:521
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:399
Definition: gemm_global_tile.h:58
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long batch_stride, Index ldm, Index bound, Index epilogue_stride_w, Index epilogue_delta_w)
Setup the params.
Definition: gemm_global_tile.h:441
PredicateVector predicates
The predicates.
Definition: gemm_global_tile.h:241
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_global_tile.h:76
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
CUTLASS_HOST_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:285
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: gemm_global_tile.h:341
Base::Fragment Fragment
Fragment type loaded by the iterator.
Definition: gemm_global_tile.h:184
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:419
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:151
Definition: gemm_operand.h:67
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:106
CUTLASS_HOST_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: gemm_global_tile.h:283
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:771
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:344
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment)
Definition: gemm_global_tile.h:606
Base::Threads Threads
Definition: gemm_global_tile.h:142
Index stride_h
The stride in the H dimension to setup the thread in the block.
Definition: gemm_global_tile.h:432
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:289
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:108
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:538
Index inc_h
Definition: tile_iterator.h:224
Shape< 0, 0, Threads::kW *ThreadsDelta::kW, kAccessSize > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: gemm_global_tile.h:95
Statically sized array of bits implementing a vector of boolean predicates.
Definition: predicate_vector.h:105
Definition: load_store.h:60
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:423
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Base::ImmediateOffsetStrides ImmediateOffsetStrides
Definition: gemm_global_tile.h:146
long long stride_d
Definition: tile_iterator.h:219
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:415
Index inc_h
Definition: gemm_global_tile.h:434
cutlass::PredicateVector< Base::Iterations::kW > predicates
The predicates for the row.
Definition: gemm_global_tile.h:473
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
Pointer pointer
The pointer.
Definition: gemm_global_tile.h:428
GemmGlobalIteratorAb< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:171
TileTraits_::Tile Tile
The tile.
Definition: gemm_global_tile.h:182
ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:88
Shape< 1, VectorizedTile::kH/Threads::kH, VectorizedTile::kW/Threads::kW, VectorizedTile::kC/kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_global_tile.h:101
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const &_params, const Coord< 3 > &bounds, const Coord< 3 > &block_offset, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:476
CUTLASS_HOST_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:523
CUTLASS_HOST_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:516
CUTLASS_HOST_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: gemm_global_tile.h:514
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:493
Params params
Parameters.
Definition: gemm_global_tile.h:469
Definition: gemm_global_tile.h:396
Definition: matrix_traits.h:159
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:428
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:471
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:152
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:192
static int const kW
The width of the cube.
Definition: shape.h:70
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:365
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset)
Definition: gemm_global_tile.h:351
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:185
Parameters.
Definition: tile_iterator.h:491
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:149
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
long long stride_d
The stride in the D dimension.
Definition: gemm_global_tile.h:430
CUTLASS_HOST_DEVICE Index stride_advance(void)
Definition: gemm_global_tile.h:353
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:680
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_global_tile.h:80
Tile_ Tile
Definition: reshape_tile.h:43
Definition: tile_iterator.h:65
Base::Iterations Iterations
Definition: gemm_global_tile.h:140
Index_ Index
The index.
Definition: gemm_global_tile.h:190
TileTraits_::Pointer Pointer
The pointer.
Definition: gemm_global_tile.h:417
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the element at (d, h, w, c) valid?
Definition: gemm_global_tile.h:335
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:431
Kind
Definition: matrix_traits.h:357
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:186
Threads_ Threads
Definition: gemm_global_tile.h:54
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, long long stride_d, Index stride_h)
Initializes params to load a strip-mined tile, given pointer and stride_h.
Definition: gemm_global_tile.h:203
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset)
Adds a pointer offset to the iterator.
Definition: gemm_global_tile.h:571
Params params
The parameters.
Definition: gemm_global_tile.h:239
Defines properties of matrices used to denote layout and operands to GEMM kernels.
CUTLASS_HOST_DEVICE void initialize_predicates(const Coord< 3 > &bounds, const Coord< 3 > &block_offset)
Definition: gemm_global_tile.h:243
The params.
Definition: gemm_global_tile.h:426
Base::ThreadsDelta ThreadsDelta
Definition: gemm_global_tile.h:144
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:237
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:467
Computes derived counts of a Layout Concept based class.
Definition: shape.h:79
CUTLASS_HOST_DEVICE void residue(Index k)
That's the residue! Update the predicates.
Definition: gemm_global_tile.h:306
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: gemm_global_tile.h:438
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:74
CUTLASS_HOST_DEVICE void store_element(typename Base::AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: gemm_global_tile.h:552
Index stride_w
Definition: tile_iterator.h:221