30 #if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700) 35 #define CUTLASS_USE_WMMA_API 51 template <MatrixLayout::Kind kLayout_>
53 typedef nvcuda::wmma::col_major Layout;
58 struct WmmaLayout<MatrixLayout::kRowMajor> {
59 typedef nvcuda::wmma::row_major Layout;
74 template <MatrixLayout::Kind kLayout_,
typename Scalar_,
typename WmmaShape_>
75 struct WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_>
76 :
public nvcuda::wmma::fragment<
78 nvcuda::wmma::matrix_a,
86 typename WmmaLayout<kLayout_>::Layout> {
88 typedef WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_> This_;
91 CUTLASS_DEVICE This_& operator=(Scalar_
const& x) {
92 nvcuda::wmma::fill_fragment(*
this, x);
97 CUTLASS_DEVICE
void load(Scalar_
const* pointer,
int const stride) {
98 nvcuda::wmma::load_matrix_sync(*
this, pointer, stride);
102 CUTLASS_DEVICE
void store(Scalar_* pointer,
int const stride)
const {
103 nvcuda::wmma::store_matrix_sync(pointer, *
this, stride);
110 template <MatrixLayout::Kind kLayout_,
typename Scalar_,
typename WmmaShape_>
111 struct WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_>
112 :
public nvcuda::wmma::fragment<
114 nvcuda::wmma::matrix_b,
122 typename WmmaLayout<kLayout_>::Layout> {
124 typedef WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_> This_;
127 CUTLASS_DEVICE This_& operator=(Scalar_
const& x) {
128 nvcuda::wmma::fill_fragment(*
this, x);
133 CUTLASS_DEVICE
void load(Scalar_
const* pointer,
int const stride) {
134 nvcuda::wmma::load_matrix_sync(*
this, pointer, stride);
138 CUTLASS_DEVICE
void store(Scalar_* pointer,
int const stride)
const {
139 nvcuda::wmma::store_matrix_sync(pointer, *
this, stride);
146 template <MatrixLayout::Kind kLayout_,
typename Scalar_,
typename WmmaShape_>
147 struct WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_>
148 :
public nvcuda::wmma::fragment<
150 nvcuda::wmma::accumulator,
158 typedef WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_> This_;
163 CUTLASS_DEVICE This_& operator=(Scalar_
const& x) {
164 nvcuda::wmma::fill_fragment(*
this, x);
169 CUTLASS_DEVICE
void load(Scalar_
const* pointer,
int const stride) {
171 nvcuda::wmma::load_matrix_sync(
175 kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
179 CUTLASS_DEVICE
void store(Scalar_* pointer,
int const stride)
const {
181 nvcuda::wmma::store_matrix_sync(
185 kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
193 #endif // defined CUTLASS_USE_WMMA_API
Definition: matrix_traits.h:36
Defines abstractions for efficiently loading and storing vectors to memory.
Defines a 1D vector of elements held in the registers of each thread.
Kind
Definition: matrix_traits.h:36
Kind
Definition: matrix_traits.h:43
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...