cutlass/cutlass/wmma_matrix.h
2018-09-18 16:58:03 -07:00

237 lines
8.7 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Abstractions for loading and storing matrices using the CUDA WMMA API.
*/
#pragma once
#if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
#define CUTLASS_USE_WMMA_API
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
#define CUTLASS_USE_SUBBYTE_WMMA
#endif
#include "stdio.h"
#if __CUDACC_VER_MAJOR__ >= 10
#include <mma.h>
#else
#include <crt/mma.h>
#endif
#include "cutlass/fragment.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/shape.h"
#include "cutlass/vector.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps a cutlass::MatrixLayout::Kind value to an nvcuda::wmma layout tag.
///
/// Primary template: selected for every layout kind that has no explicit specialization, and
/// exposes nvcuda::wmma::col_major as the nested Layout type.
template <MatrixLayout::Kind kLayout_>
struct WmmaLayout {
/// Layout tag type; WmmaMatrix passes it as the last template argument of nvcuda::wmma::fragment.
typedef nvcuda::wmma::col_major Layout;
};
/// Explicit specialization of WmmaLayout for MatrixLayout::kRowMajor: exposes
/// nvcuda::wmma::row_major as the nested Layout type.
template <>
struct WmmaLayout<MatrixLayout::kRowMajor> {
/// Layout tag type selected for row-major matrices.
typedef nvcuda::wmma::row_major Layout;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Type trait translating a CUTLASS element type into the element type given to nvcuda::wmma.
///
/// Primary template: the identity mapping, i.e. WmmaDataType<T>::Type is T itself. Explicit
/// specializations may substitute a different type for particular arguments.
template <typename Scalar_>
struct WmmaDataType {
  /// The mapped type; identical to the template argument in this primary template.
  typedef Scalar_ Type;
};
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Statically maps cutlass::Vector<bin1_t, 32> to nvcuda::wmma::experimental::precision::b1.
/// This specialization is only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
template<>
struct WmmaDataType<Vector<bin1_t, 32> > {
/// WMMA element type substituted for Vector<bin1_t, 32>.
typedef nvcuda::wmma::experimental::precision::b1 Type;
};
/// Statically maps cutlass::Vector<int4_t, 8> to nvcuda::wmma::experimental::precision::s4.
/// This specialization is only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
template<>
struct WmmaDataType<Vector<int4_t, 8> > {
/// WMMA element type substituted for Vector<int4_t, 8>.
typedef nvcuda::wmma::experimental::precision::s4 Type;
};
/// Statically maps cutlass::Vector<uint4_t, 8> to nvcuda::wmma::experimental::precision::u4.
/// This specialization is only compiled when CUTLASS_USE_SUBBYTE_WMMA is defined.
template<>
struct WmmaDataType<Vector<uint4_t, 8> > {
/// WMMA element type substituted for Vector<uint4_t, 8>.
typedef nvcuda::wmma::experimental::precision::u4 Type;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Adapter to nvcuda::wmma fragment load and store operations.
///
/// Primary template: an empty struct. The members (operator=, load, store) are provided by the
/// partial specializations on kOperand_ for GemmOperand::kA, GemmOperand::kB and GemmOperand::kC;
/// any other operand kind instantiates this empty definition.
///
/// kOperand_  - GEMM operand the matrix represents; selects the specialization.
/// kLayout_   - matrix layout kind.
/// Scalar_    - element type as seen by CUTLASS.
/// WmmaShape_ - type providing the kW, kH and kD constants read by the specializations.
template <GemmOperand::Kind kOperand_,
MatrixLayout::Kind kLayout_,
typename Scalar_,
typename WmmaShape_>
struct WmmaMatrix {};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// WmmaMatrix partial specialization for the A operand.
///
/// Publicly derives from the nvcuda::wmma::fragment tagged nvcuda::wmma::matrix_a. The three
/// integer fragment arguments are WmmaShape_::kW, WmmaShape_::kH and WmmaShape_::kD, in that
/// order; the element type is WmmaDataType<Scalar_>::Type and the layout tag is
/// WmmaLayout<kLayout_>::Layout.
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,
                                    WmmaShape_::kW,
                                    WmmaShape_::kH,
                                    WmmaShape_::kD,
                                    typename WmmaDataType<Scalar_>::Type,
                                    typename WmmaLayout<kLayout_>::Layout> {
  /// Alias for this specialization (spelled through the injected class name).
  typedef WmmaMatrix This_;

  /// Fills the fragment with x through nvcuda::wmma::fill_fragment and returns *this.
  CUTLASS_DEVICE This_& operator=(const Scalar_& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Loads the fragment through nvcuda::wmma::load_matrix_sync; pointer and stride are forwarded
  /// unchanged.
  CUTLASS_DEVICE void load(const Scalar_* pointer, const int stride) {
    nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
  }

  /// Forwards to the three-argument call nvcuda::wmma::store_matrix_sync(pointer, *this, stride).
  /// NOTE(review): the public WMMA API appears to document store_matrix_sync only for accumulator
  /// fragments and with an explicit layout argument, so this member presumably fails to compile
  /// if it is ever instantiated - confirm before relying on it.
  CUTLASS_DEVICE void store(Scalar_* pointer, const int stride) const {
    nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// WmmaMatrix partial specialization for the B operand.
///
/// Publicly derives from the nvcuda::wmma::fragment tagged nvcuda::wmma::matrix_b. The three
/// integer fragment arguments are WmmaShape_::kW, WmmaShape_::kH and WmmaShape_::kD, in that
/// order; the element type is WmmaDataType<Scalar_>::Type and the layout tag is
/// WmmaLayout<kLayout_>::Layout.
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,
                                    WmmaShape_::kW,
                                    WmmaShape_::kH,
                                    WmmaShape_::kD,
                                    typename WmmaDataType<Scalar_>::Type,
                                    typename WmmaLayout<kLayout_>::Layout> {
  /// Alias for this specialization (spelled through the injected class name).
  typedef WmmaMatrix This_;

  /// Fills the fragment with x through nvcuda::wmma::fill_fragment and returns *this.
  CUTLASS_DEVICE This_& operator=(const Scalar_& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Loads the fragment through nvcuda::wmma::load_matrix_sync; pointer and stride are forwarded
  /// unchanged.
  CUTLASS_DEVICE void load(const Scalar_* pointer, const int stride) {
    nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
  }

  /// Forwards to the three-argument call nvcuda::wmma::store_matrix_sync(pointer, *this, stride).
  /// NOTE(review): the public WMMA API appears to document store_matrix_sync only for accumulator
  /// fragments and with an explicit layout argument, so this member presumably fails to compile
  /// if it is ever instantiated - confirm before relying on it.
  CUTLASS_DEVICE void store(Scalar_* pointer, const int stride) const {
    nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// WmmaMatrix partial specialization for the C operand.
///
/// Publicly derives from the nvcuda::wmma::fragment tagged nvcuda::wmma::accumulator, with
/// WmmaShape_::kW, WmmaShape_::kH and WmmaShape_::kD (in that order) as its integer arguments and
/// Scalar_ as its element type. No layout tag is passed to the fragment here; kLayout_ is kept
/// as the static member kLayout and decides which memory-layout enumerator load() and store()
/// pass on.
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<nvcuda::wmma::accumulator,
                                    WmmaShape_::kW,
                                    WmmaShape_::kH,
                                    WmmaShape_::kD,
                                    Scalar_> {
  /// Alias for this specialization (spelled through the injected class name).
  typedef WmmaMatrix This_;

  /// Layout kind this specialization was instantiated with.
  static MatrixLayout::Kind const kLayout = kLayout_;

  /// Fills the fragment with x through nvcuda::wmma::fill_fragment and returns *this.
  CUTLASS_DEVICE This_& operator=(const Scalar_& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Loads the fragment through nvcuda::wmma::load_matrix_sync, passing mem_row_major when the
  /// layout is MatrixLayout::kRowMajor and mem_col_major otherwise.
  CUTLASS_DEVICE void load(const Scalar_* pointer, const int stride) {
    // The condition depends only on a template argument, so it is a compile-time constant.
    if (kLayout_ == MatrixLayout::kRowMajor) {
      nvcuda::wmma::load_matrix_sync(*this, pointer, stride, nvcuda::wmma::mem_row_major);
    } else {
      nvcuda::wmma::load_matrix_sync(*this, pointer, stride, nvcuda::wmma::mem_col_major);
    }
  }

  /// Stores the fragment through nvcuda::wmma::store_matrix_sync, passing mem_row_major when the
  /// layout is MatrixLayout::kRowMajor and mem_col_major otherwise.
  CUTLASS_DEVICE void store(Scalar_* pointer, const int stride) const {
    // The condition depends only on a template argument, so it is a compile-time constant.
    if (kLayout_ == MatrixLayout::kRowMajor) {
      nvcuda::wmma::store_matrix_sync(pointer, *this, stride, nvcuda::wmma::mem_row_major);
    } else {
      nvcuda::wmma::store_matrix_sync(pointer, *this, stride, nvcuda::wmma::mem_col_major);
    }
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// WmmaMatrix cannot be used in a union and thus cannot be used in our Vector implementation.
// The only use of WmmaMatrix in combination with Vectorize has kLanes == 1. Because of this, it is
// safe to keep the Vector->Scalar conversion for WmmaMatrix.
// Partial specialization of Vectorize for any WmmaMatrix type with a second argument of 1.
template <GemmOperand::Kind kOperand_,
MatrixLayout::Kind kLayout_,
typename Scalar_,
typename WmmaShape_>
struct Vectorize<WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_>, 1> {
// The resulting type is the WmmaMatrix type itself; it is not wrapped in a Vector.
typedef WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_> Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
#endif // defined CUTLASS_USE_WMMA_API