32 #ifdef CUTLASS_USE_WMMA_API 53 struct WmmaGemmSharedLoadTileATraits {
59 typedef Scalar_ Scalar;
61 typedef Scalar
const* Pointer;
63 static int const kAccessSize = 1;
69 static int const kWarpStride = kWarpStride_;
71 typedef Iterations_ Iterations;
75 typedef Delta_ ImmediateOffsetStrides;
77 typedef WmmaShape_ WmmaShape;
83 Coord<4> operator()()
const {
85 int const warp = threadIdx.x / kWarpSize;
87 int const offset = warp % Warps::kW * kWarpStride;
100 typename Iterations_,
103 struct WmmaGemmSharedLoadTileBTraits {
109 typedef Scalar_ Scalar;
111 typedef Scalar
const* Pointer;
113 static int const kAccessSize = 1;
117 typedef Warps_ Warps;
119 static int const kWarpStride = kWarpStride_;
121 typedef Iterations_ Iterations;
123 typedef Delta_ Delta;
125 typedef Delta_ ImmediateOffsetStrides;
127 typedef WmmaShape_ WmmaShape;
131 struct ThreadOffset {
133 Coord<4> operator()()
const {
135 int const warp = threadIdx.x / kWarpSize;
137 int const offset = warp / Warps::kW * kWarpStride;
147 typename OutputTile_,
151 struct WmmaGemmSharedStoreTileDTraits {
157 typedef Scalar_ Scalar;
159 static int const kAccessSize = 1;
161 typedef Scalar* Pointer;
163 typedef Warps_ Warps;
165 typedef WmmaShape_ WmmaShape;
167 static int const kSkew = kSkew_;
171 typedef Shape<1, Warps_::kH * WmmaShape_::kH, OutputTile_::kW + kSkew_> Tile;
173 typedef Shape<1, 1, OutputTile_::kW / Warps::kW / WmmaShape_::kW> Iterations;
175 typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> Delta;
177 typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;
180 struct ThreadOffset {
182 Coord<4> operator()()
const {
184 int const warp = threadIdx.x / kWarpSize;
186 int const h = warp / Warps::kW * WmmaShape::kH;
188 int const w = warp % Warps::kW * WmmaShape::kW;
190 int const offset = h * Tile::kW + w;
198 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kScalarsPerLds_>
199 struct WmmaGemmSharedLoadTileDTraits {
201 typedef Scalar_ Scalar;
203 typedef Scalar
const* Pointer;
205 static int const kAccessSize = kScalarsPerLds_;
211 typedef Shape<1, Tile::kW * Tile::kC, Tile::kC> ThreadsStrides;
216 typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
218 typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_>
219 ImmediateOffsetStrides;
221 typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
225 struct ThreadOffset {
227 Coord<4> operator()()
const {
240 #endif // defined CUTLASS_USE_WMMA_API static CUTLASS_DEVICE int get()
Definition: shape.h:253
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:42
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Definition: matrix_traits.h:43
Kind
Definition: load_store.h:40
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Definition: matrix_traits.h:43
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Kind
Definition: matrix_traits.h:36
Tile_ Tile
Definition: reshape_tile.h:43
Kind
Definition: matrix_traits.h:43
Definition: matrix_traits.h:43
Threads_ Threads
Definition: gemm_global_tile.h:54