37 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kAccessSize_>
39 MatrixLayout::kColumnMajor,
60 int thread_offset_h = threadIdx.x / Base::Threads::kW;
63 return make_Coord(0, thread_offset_h, thread_offset_w, 0);
70 template <
typename TileTraits_,
typename Index_ =
int>
72 typename TileTraits_::Scalar,
82 typename TileTraits_::Scalar,
93 typedef typename TileTraits_::Scalar
Scalar;
95 typedef typename TileTraits_::Pointer
Pointer;
97 typedef typename TileTraits_::Threads
Threads;
124 inc_h = ld * TileTraits_::Threads::kH;
147 int const pointer_offset = 0,
148 int const pred_offset = 0,
161 for (
int i = 0; i < Base::Iterations::kW; ++i) {
185 CUTLASS_DEVICE
bool valid(
int d,
int h,
int w,
int c)
const {
TileTraits_::Threads Threads
The threads.
Definition: wmma_gemm_global_tile.h:97
Defines iterators for efficiently loading and storing to global memory.
Definition: gemm_global_tile.h:70
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:356
CUTLASS_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: wmma_gemm_global_tile.h:177
Definition: load_store.h:43
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Index stride_h
The stride in the H dimension used to set up the thread in the block.
Definition: wmma_gemm_global_tile.h:108
CUTLASS_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: wmma_gemm_global_tile.h:170
Index_ Index
The index.
Definition: wmma_gemm_global_tile.h:99
TileTraits_::Scalar Scalar
The scalar.
Definition: wmma_gemm_global_tile.h:93
Definition: tile_iterator.h:62
Definition: matrix_traits.h:43
Params params
Definition: wmma_gemm_global_tile.h:136
Index predicate_inc_h
The strides to increment the predicate offset.
Definition: wmma_gemm_global_tile.h:114
Pointer pointer
The pointer.
Definition: wmma_gemm_global_tile.h:106
CUTLASS_HOST_DEVICE Pointer const data() const
Definition: wmma_gemm_global_tile.h:194
CUTLASS_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: wmma_gemm_global_tile.h:179
The params.
Definition: wmma_gemm_global_tile.h:104
Index inc_h
The strides to increment the pointer.
Definition: wmma_gemm_global_tile.h:110
TileIteratorBase< Traits, typename TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: wmma_gemm_global_tile.h:86
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd()
Ctor.
Definition: wmma_gemm_global_tile.h:141
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: wmma_gemm_global_tile.h:112
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: wmma_gemm_global_tile.h:59
Definition: wmma_gemm_global_tile.h:71
Index predicate_inc_advance
Definition: wmma_gemm_global_tile.h:114
TileTraits_::Pointer Pointer
The pointer.
Definition: wmma_gemm_global_tile.h:95
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: wmma_gemm_global_tile.h:54
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Coord< 4 > thread_offset
Definition: wmma_gemm_global_tile.h:138
Index inc_advance
Definition: wmma_gemm_global_tile.h:110
static MatrixLayout::Kind const kLayout
The layout.
Definition: wmma_gemm_global_tile.h:90
Definition: wmma_gemm_global_tile.h:38
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:102
TileTraits_::ThreadOffset ThreadOffset
The thread offset functor.
Definition: wmma_gemm_global_tile.h:101
Definition: matrix_traits.h:36
CUTLASS_HOST_DEVICE Pointer data()
Returns the raw pointer.
Definition: wmma_gemm_global_tile.h:191
static int const kW
The width of the cube.
Definition: shape.h:70
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:364
Kind
Definition: matrix_traits.h:36
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: wmma_gemm_global_tile.h:51
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const &params, const Coord< 3 > &bounds, const Coord< 3 > &block, int const pointer_offset=0, int const pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: wmma_gemm_global_tile.h:144
WmmaGemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: wmma_gemm_global_tile.h:77
cutlass::PredicateVector< Base::Iterations::kW > predicates
The predicates for the row.
Definition: wmma_gemm_global_tile.h:197
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > ImmediateOffsetStrides
Override the strides in each dimension between different loads/stores.
Definition: wmma_gemm_global_tile.h:88
Computes the thread offset in (H, W) based on thread ID.
Definition: wmma_gemm_global_tile.h:57
CUTLASS_DEVICE void inc_c()
Increment the pointer in the C dimension.
Definition: wmma_gemm_global_tile.h:168
CUTLASS_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: wmma_gemm_global_tile.h:172
TileTraits_ Traits
The traits.
Definition: wmma_gemm_global_tile.h:79
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Test the predicate.
Definition: wmma_gemm_global_tile.h:185
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld, Index n, Index epilogue_stride_w, Index epilogue_delta_w)
Set up the params.
Definition: wmma_gemm_global_tile.h:117