52 template <
typename Tile_,
typename Threads_,
bool = (Tile_::kW < Threads_::kW)>
53 struct ReshapeThreads {
54 typedef Threads_ Threads;
57 template <
typename Tile_,
typename Threads_>
59 typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1>
Threads;
96 typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC /
kAccessSize>
108 return make_Coord(0, thread_offset_h, thread_offset_w, 0);
115 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kStrideH_,
int kAccessSize_>
117 MatrixLayout::kColumnMajor,
151 return make_Coord(0, thread_offset_h, thread_offset_w, 0);
158 template <
typename TileTraits_,
typename Index_ =
int>
161 typename TileTraits_::Scalar,
162 TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
163 : IteratorAdvance::kW,
164 MemorySpace::kGlobal,
170 typename TileTraits_::Scalar,
181 typedef typename TileTraits_::Scalar
Scalar;
183 typedef typename TileTraits_::Threads
Threads;
205 if (Base::Delta::kD > 0) {
214 }
else if (Base::Delta::kD > 0) {
216 (Base::Iterations::kH - 1) *
inc_h -
217 (Base::Iterations::kD - 1) * Base::Delta::kD *
stride_h;
220 (Base::Iterations::kH - 1) *
inc_h;
237 int bounds_h, bounds_w;
239 bounds_w = bounds[2] - block[2];
240 bounds_h = bounds[1];
243 bounds_w = bounds[1];
244 bounds_h = bounds[2] - block[1];
248 for (
int d = 0; d < Base::Iterations::kD; ++d) {
249 for (
int h = 0; h < Base::Iterations::kH; ++h) {
250 for (
int w = 0; w < Base::Iterations::kW; ++w) {
251 for (
int c = 0; c < Base::Iterations::kC; ++c) {
252 bool flag = w * Base::Delta::kW < bounds_w;
254 flag = flag && (h * Base::Delta::kH + d * Base::Delta::kD) < bounds_h;
256 flag = flag && (h * Base::Delta::kH) < bounds_h;
314 for (
int d = 0; d < Base::Iterations::kD; ++d) {
315 for (
int h = 0; h < Base::Iterations::kH; ++h) {
316 for (
int w = 0; w < Base::Iterations::kW; ++w) {
317 for (
int c = 0; c < Base::Iterations::kC; ++c) {
320 offset += block_h + h * Base::Delta::kH + d * Base::Delta::kD;
322 offset += block_w + w * Base::Delta::kW;
336 CUTLASS_DEVICE
bool valid(
int d,
int h,
int w,
int c)
const {
347 template <
typename TileTraits_,
typename Index_ =
int>
349 typename TileTraits_::Scalar,
351 MemorySpace::kGlobal,
357 typename TileTraits_::Scalar,
367 typedef typename TileTraits_::Scalar
Scalar;
369 typedef typename TileTraits_::Pointer
Pointer;
371 typedef typename TileTraits_::Threads
Threads;
396 stride_h = TileTraits_::ThreadsDelta::kH * ld;
399 inc_h = ld * TileTraits_::kStrideH;
401 (ld - ld * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
406 -((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
436 for (
int i = 0; i < Base::Iterations::kW; ++i) {
460 CUTLASS_DEVICE
bool valid(
int d,
int h,
int w,
int c)
const {
Definition: gemm_global_tile.h:116
Shape< 0, Threads::kH, Threads::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:92
Index inc_advance
The strides to increment the pointer.
Definition: gemm_global_tile.h:384
CUTLASS_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:452
cutlass::PredicateVector< ShapeCount< typename Base::Iterations >::kCount > PredicateVector
Definition: gemm_global_tile.h:191
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:177
Base::Params BaseParams
Iterator parameters type.
Definition: gemm_global_tile.h:194
Shape< 1, Tile::kH/Threads::kH, Tile::kW/Threads::kW, Tile::kC/kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_global_tile.h:97
Index_ Index
The index.
Definition: gemm_global_tile.h:373
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
GemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:354
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:364
Definition: gemm_global_tile.h:70
Scalar_ * Pointer
The pointer.
Definition: gemm_global_tile.h:78
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:62
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:356
Definition: load_store.h:43
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
GemmMultiplicandTraits< Tile, kOperand, kLayout > MultiplicandTraits
Definition: gemm_global_tile.h:99
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_global_tile.h:82
TileIteratorBase< TileTraits_, typename TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:361
Shape< 1, 1, Tile::kC > ThreadsDelta
The relative offset between two elements in the H/W dimension in adjacent threads.
Definition: gemm_global_tile.h:89
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:134
Index predicate_inc_h
Definition: gemm_global_tile.h:386
static CUTLASS_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:166
CUTLASS_HOST_DEVICE Pointer const data() const
Definition: gemm_global_tile.h:469
CUTLASS_DEVICE void initialize_predicates(const Coord< 3 > &bounds, const Coord< 3 > &block)
Definition: gemm_global_tile.h:233
Definition: tile_iterator.h:62
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:331
TileLoadIterator< TileTraits_, typename TileTraits_::Scalar, TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:175
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: gemm_global_tile.h:336
Definition: gemm_global_tile.h:196
Definition: matrix_traits.h:43
Definition: gemm_global_tile.h:159
CUTLASS_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:454
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: gemm_global_tile.h:129
Kind
Definition: load_store.h:40
Index stride_h
Definition: tile_iterator.h:172
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: gemm_global_tile.h:189
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:183
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index stride_h)
Initializes params to load a strip-mined tile, given pointer and stride_h.
Definition: gemm_global_tile.h:198
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:425
static int const kStrideH
The stride in the H dimension.
Definition: gemm_global_tile.h:132
static int const kH
The height of the cube.
Definition: shape.h:68
Shape< Threads_::kD, Threads_::kH *Threads_::kW/Tile_::kW, Tile_::kW, 1 > Threads
Definition: gemm_global_tile.h:59
Index predicate_inc_advance
The strides to increment the predicate offset.
Definition: gemm_global_tile.h:386
static GemmOperand::Kind const kOperand
Identity of the operand.
Definition: gemm_global_tile.h:72
Index inc_h
Definition: tile_iterator.h:176
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:302
Definition: gemm_global_tile.h:58
PredicateVector predicates
The predicates.
Definition: gemm_global_tile.h:342
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_global_tile.h:76
CUTLASS_HOST_DEVICE Scalar const * data() const
Returns the current pointer.
Definition: gemm_global_tile.h:304
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Base::Fragment Fragment
Fragment type loaded by the iterator.
Definition: gemm_global_tile.h:179
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:371
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:147
CUTLASS_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:447
CUTLASS_DEVICE GemmGlobalIteratorCd(Params const &params, const Coord< 3 > &bounds, const Coord< 3 > &block, int offset=0, int pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:420
Definition: gemm_operand.h:67
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:102
Index inc_advance
Definition: tile_iterator.h:179
CUTLASS_DEVICE void residue(Index k)
That's the residue! Update the predicates.
Definition: gemm_global_tile.h:307
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:343
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const &_params, const Coord< 3 > &bounds, const Coord< 3 > &block, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:267
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld, Index bound, Index epilogue_stride_w, Index epilogue_delta_w)
Setup the params.
Definition: gemm_global_tile.h:391
CUTLASS_DEVICE void inc_c()
Increment the pointer in the C dimension.
Definition: gemm_global_tile.h:443
CUTLASS_HOST_DEVICE Pointer data()
Returns the raw pointer.
Definition: gemm_global_tile.h:466
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:390
Base::Threads Threads
Definition: gemm_global_tile.h:138
Index stride_h
The stride in the H dimension to setup the thread in the block.
Definition: gemm_global_tile.h:382
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:104
Shape< 0, 0, Threads::kW *ThreadsDelta::kW, kAccessSize > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: gemm_global_tile.h:94
Statically sized array of bits implementing a vector of boolean predicates.
Definition: predicate_vector.h:104
CUTLASS_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:296
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:375
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Base::ImmediateOffsetStrides ImmediateOffsetStrides
Definition: gemm_global_tile.h:142
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:367
Index inc_h
Definition: gemm_global_tile.h:384
cutlass::PredicateVector< Base::Iterations::kW > predicates
The predicates for the row.
Definition: gemm_global_tile.h:472
CUTLASS_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:298
Pointer pointer
The pointer.
Definition: gemm_global_tile.h:380
GemmGlobalIteratorAb< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:167
ReshapeTile< Tile_, kAccessSize_ >::Tile Tile
The tile shape.
Definition: gemm_global_tile.h:85
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:364
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:102
CUTLASS_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: gemm_global_tile.h:445
Params params
Definition: gemm_global_tile.h:412
Definition: gemm_global_tile.h:348
Definition: matrix_traits.h:36
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:414
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:187
static int const kW
The width of the cube.
Definition: shape.h:70
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:364
Parameters.
Definition: tile_iterator.h:388
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:145
Kind
Definition: matrix_traits.h:36
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_global_tile.h:80
Tile_ Tile
Definition: reshape_tile.h:43
Definition: tile_iterator.h:62
Base::Iterations Iterations
Definition: gemm_global_tile.h:136
Index_ Index
The index.
Definition: gemm_global_tile.h:185
TileTraits_::Pointer Pointer
The pointer.
Definition: gemm_global_tile.h:369
Kind
Definition: matrix_traits.h:43
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:181
Threads_ Threads
Definition: gemm_global_tile.h:54
ReshapeThreads< Tile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:87
CUTLASS_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:300
CUTLASS_DEVICE GemmGlobalIteratorCd()
Ctor.
Definition: gemm_global_tile.h:417
Params params
The parameters.
Definition: gemm_global_tile.h:231
Defines properties of matrices used to denote layout and operands to GEMM kernels.
The params.
Definition: gemm_global_tile.h:378
Base::ThreadsDelta ThreadsDelta
Definition: gemm_global_tile.h:140
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Test the validity of the iterator.
Definition: gemm_global_tile.h:460
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:229
Computes the derived counts of a Layout Concept based class.
Definition: shape.h:79
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: gemm_global_tile.h:388
Index inc_d
Definition: tile_iterator.h:175
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:74