37 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kScalarsPerSts_>
58 Tile::kH / Threads::kH,
59 Tile::kW / Threads::kW,
79 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kScalarsPerSts_,
int kSkew_>
93 static int const kSkew = kSkew_;
100 typedef Shape<1, TileWithoutSkew::kH / Threads::kW, TileWithoutSkew::kW / Threads::kH>
Iterations;
121 template <
typename Scalar_,
122 typename OutputTile_,
124 typename ThreadsPerWarp_,
125 typename InstructionShape_,
136 typedef Shape<kStages_,
137 OutputTile_::kD / InstructionShape_::kD,
177 int const warp = threadIdx.x / kWarpSize % Warps::kW;
179 int const lane = (threadIdx.x & 0x0e) / 2;
181 int const offset = (warp * ThreadsPerWarp::kW + lane) *
kAccessSize;
190 template <
typename Scalar_,
191 typename OutputTile_,
193 typename ThreadsPerWarp_,
194 typename InstructionShape_,
205 typedef Shape<kStages_,
206 OutputTile_::kD / InstructionShape_::kD,
244 int const warp = threadIdx.x / (Warps::kW * kWarpSize);
247 int const lane = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
249 int const offset = (warp * ThreadsPerWarp::kH + lane) *
kAccessSize;
258 template <
typename Scalar_,
259 typename OutputTile_,
261 typename ThreadsPerWarp_,
307 int const row = threadIdx.x & 0x01;
309 int const warp_id = (threadIdx.x >> 5);
311 int const warp_row = (warp_id % Warps::kW);
312 int const warp_col = (warp_id / Warps::kW);
314 int hi_halfwarp_offset = OutputTile::kW * ((threadIdx.x >> 4) & 1);
315 int lo_halfwarp_offset = (((threadIdx.x >> 1) & 0x7) + warp_row * ThreadsPerWarp::kW);
318 warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW + hi_halfwarp_offset;
328 template <
typename Scalar_,
329 typename OutputTile_,
331 typename ThreadsPerWarp_,
387 int const h = threadIdx.x / kWarpSize;
389 int const w = (threadIdx.x & (kWarpSize - 1)) *
kAccessSize;
392 int const row = h & 0x1;
393 int const col = h / 2;
static int const kAccessSize
The number of scalars per STS.
Definition: gemm_shared_tile.h:95
static CUTLASS_DEVICE int get()
Definition: shape.h:253
ReshapeTile< TileWithSkew, kScalarsPerLds_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:214
ReshapeTile< TileWithSkew, kScalarsPerLds_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:145
ReshapeTile< TileWithoutSkew_, kScalarsPerLds_ >::Tile TileWithoutSkew
The tile without skew after reshaping.
Definition: gemm_shared_tile.h:212
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:351
static int const kScalarsPerThread
The number of scalars per thread.
Definition: gemm_shared_tile.h:354
Definition: load_store.h:42
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:242
Shape< 1, 1, TileWithoutSkew::kW/kWarps/kThreadsPerWarp > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:232
static int const kWarps
The number of warps.
Definition: gemm_shared_tile.h:227
Definition: gemm_shared_tile.h:129
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:42
Definition: gemm_shared_tile.h:80
static int const kScalarsPerRow
The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
Definition: gemm_shared_tile.h:287
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:132
Definition: gemm_shared_tile.h:106
Shape< 1, 1, kScalarsPerThread/kAccessSize > Iterations
The number of iterations needed to store the tile.
Definition: gemm_shared_tile.h:292
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:347
ThreadsPerWarp_ ThreadsPerWarp
The threads in a warp.
Definition: gemm_shared_tile.h:149
Definition: reshape_tile.h:42
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Shape< 0, ShapeCount< Tile >::kWc, Tile::kC, kScalarsPerSts_ > ThreadsStrides
The strides to compute the base position of the thread.
Definition: gemm_shared_tile.h:48
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:276
Shape< kIterationsD, kIterationsH, OutputTile::kW/kWarpSize/kAccessSize > Iterations
The number of iterations needed to store the tile.
Definition: gemm_shared_tile.h:376
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:349
Warps_ Warps
The number of warps.
Definition: gemm_shared_tile.h:216
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:222
Definition: gemm_shared_tile.h:38
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:201
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:383
Definition: gemm_shared_tile.h:198
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:156
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:40
static GemmOperand::Kind const kOperand
Definition: gemm_shared_tile.h:130
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:280
Kind
Definition: load_store.h:40
Shape< kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW+kSkew_ > TileWithSkew
The tile with skew.
Definition: gemm_shared_tile.h:210
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:152
static int const kH
The height of the cube.
Definition: shape.h:68
Shape< 1, Tile::kH/Threads::kH, Tile::kW/Threads::kW, Tile::kC/Threads::kC/kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:61
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:93
Shape< 1, 1, TileWithoutSkew::kW/kWarps/kThreadsPerWarp > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:165
OutputTile_ OutputTile
The dimension of the output tile.
Definition: gemm_shared_tile.h:270
static int const kScalarsPerRow
The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
Definition: gemm_shared_tile.h:358
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:203
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:134
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:268
static int const kScalarsPerThread
The number of scalars per thread.
Definition: gemm_shared_tile.h:283
Shape< OutputTile::kW, kScalarsPerRow, kWarpSize *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:380
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:301
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:54
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:50
static int const kThreadsPerWarp
The number of threads in one dimension of the warp.
Definition: gemm_shared_tile.h:229
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:240
Shape< 0, ShapeCount< Tile >::kWc, Threads::kH *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:104
Shape< 1, 2, kScalarsPerRow/kAccessSize, kAccessSize > Tile
The tile.
Definition: gemm_shared_tile.h:290
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:52
ReshapeTile< Tile_, kScalarsPerSts_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:44
Definition: gemm_shared_tile.h:68
static int const kIterationsInHPerWarp
Definition: gemm_shared_tile.h:364
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:278
ReshapeTile< TileWithoutSkew_, kScalarsPerLds_ >::Tile TileWithoutSkew
The tile without skew after reshaping.
Definition: gemm_shared_tile.h:143
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Shape< 0, Threads::kH *ShapeCount< Tile >::kWc, Threads::kW *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:66
Shape< TileWithSkew::kW, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:170
ReshapeTile< Shape< Tile_::kD, Tile_::kH, Tile_::kW+kSkew_ >, kScalarsPerSts_ >::Tile Tile
The tile.
Definition: gemm_shared_tile.h:89
Shape< 0, kScalarsPerSts_, ShapeCount< Tile >::kHwc/Threads::kW > ThreadsStrides
The strides to compute the base position of the thread.
Definition: gemm_shared_tile.h:116
ReshapeTile< Tile_, kScalarsPerSts_ >::Tile TileWithoutSkew
The tile without skews.
Definition: gemm_shared_tile.h:86
static int const kIterationsD
Definition: gemm_shared_tile.h:373
static int const kWarps
The number of warps.
Definition: gemm_shared_tile.h:159
Definition: matrix_traits.h:43
ThreadsPerWarp_ ThreadsPerWarp
The threads in the warps.
Definition: gemm_shared_tile.h:274
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:173
Shape< kStages_, OutputTile_::kD/InstructionShape_::kD, GetExtent< kOperand, OutputTile_ >::kExtent *InstructionShape_::kD > TileWithoutSkew_
The tile without skew.
Definition: gemm_shared_tile.h:139
Definition: gemm_shared_tile.h:335
Threads_ Threads
The threads.
Definition: gemm_shared_tile.h:91
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
OutputTile_ OutputTile
The dimension of the output tile.
Definition: gemm_shared_tile.h:341
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:82
Shape< TileWithSkew::kW, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:167
Shape< 0, 0, Warps::kW *ThreadsPerWarp::kW *kAccessSize > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:296
static GemmOperand::Kind const kOperand
Definition: gemm_shared_tile.h:199
Shape< 1, 2, kScalarsPerRow/kAccessSize, kAccessSize > Tile
The tile.
Definition: gemm_shared_tile.h:361
static int const kThreadsPerWarp
The number of threads in one dimension of the warp.
Definition: gemm_shared_tile.h:161
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:84
Shape< OutputTile::kW, kScalarsPerRow, kWarpSize *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:378
Shape< 1, TileWithoutSkew::kH/Threads::kW, TileWithoutSkew::kW/Threads::kH > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_shared_tile.h:100
Shape< kStages_, OutputTile_::kD/InstructionShape_::kD, GetExtent< kOperand, OutputTile_ >::kExtent *InstructionShape_::kD > TileWithoutSkew_
The tile without skew.
Definition: gemm_shared_tile.h:208
Threads_ Threads
The threads.
Definition: gemm_shared_tile.h:46
Definition: gemm_operand.h:50
Shape< 0, Threads::kH *ShapeCount< Tile >::kWc, Threads::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:63
static int const kThreads
The number of threads.
Definition: gemm_shared_tile.h:356
Warps_ Warps
The number of warps.
Definition: gemm_shared_tile.h:147
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:97
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_shared_tile.h:224
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:70
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:175
static int const kD
The depth of the cube.
Definition: shape.h:66
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_shared_tile.h:299
Warps_ Warps
The warps in the tile.
Definition: gemm_shared_tile.h:343
Tile_ Tile
Definition: reshape_tile.h:43
Shape< 0, ShapeCount< Tile >::kWc, Threads::kH *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:102
static int const kIterationsH
Definition: gemm_shared_tile.h:371
Shape< 0, 0, Warps::kW *ThreadsPerWarp::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:294
Kind
Definition: matrix_traits.h:43
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:154
ThreadsPerWarp_ ThreadsPerWarp
The threads in the warps.
Definition: gemm_shared_tile.h:345
Definition: matrix_traits.h:43
Scalar_ * Pointer
The pointer.
Definition: gemm_shared_tile.h:339
static int const kThreads
The number of threads.
Definition: gemm_shared_tile.h:285
Shape< TileWithSkew::kW, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:234
ThreadsPerWarp_ ThreadsPerWarp
The threads in a warp.
Definition: gemm_shared_tile.h:218
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:266
Compute derived counts of a Layout Concept based class.
Definition: shape.h:79
Shape< kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW+kSkew_ > TileWithSkew
The tile with skew.
Definition: gemm_shared_tile.h:141
Warps_ Warps
The warps in the tile.
Definition: gemm_shared_tile.h:272
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:108
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:337
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_shared_tile.h:385
Shape< TileWithSkew::kW, 0, kWarps *kThreadsPerWarp *kAccessSize, 0 > ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: gemm_shared_tile.h:237
Definition: gemm_shared_tile.h:264
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_shared_tile.h:220