Performance enhancement for Volta Tensor Cores TN layout (#53)
* Fixed performance defect with indirect access to pointer array for Volta TensorCores TN arrangement. * Updated patch version and changelog. * Updated patch version and changelog. * Added link to changelog in readme. * Fixed markdown link
This commit is contained in:
parent
eb41735933
commit
b5cab177a9
@ -1,5 +1,8 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [1.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.2) (2019-07-09)
|
||||
* Performance improvement for Volta Tensor Cores TN and TT layouts.
|
||||
|
||||
## [1.3.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.1) (2019-04-09)
|
||||
* Corrected NVRTC unit tests.
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
# CUTLASS 1.3
|
||||
|
||||
_CUTLASS 1.3.1 - April 2019_
|
||||
_CUTLASS 1.3.2 - July 2019_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
@ -28,9 +28,6 @@ CUTLASS 1.3 is described in the [CUTLASS Documentation](CUTLASS.md) and the acco
|
||||
We describe the structure of an efficient GEMM in our talk at the
|
||||
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
|
||||
# What's New in CUTLASS 1.3.1
|
||||
_April 2019_
|
||||
* CUTLASS 1.3.1 corrected NVRTC unit tests..
|
||||
|
||||
# What's New in CUTLASS 1.3
|
||||
_March 2019_
|
||||
@ -60,6 +57,8 @@ _September 2018_
|
||||
* [Reference implementations](tools/util/reference) for tensor operations in [host](tools/util/reference/host) and [device](tools/util/reference/device) code
|
||||
* Added `HostMatrix<>` for simplified matrix creation
|
||||
|
||||
For all updates, see the [CUTLASS changelog](CHANGELOG.md).
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>
|
||||
|
@ -34,7 +34,7 @@
|
||||
|
||||
#define CUTLASS_MAJOR 1
|
||||
#define CUTLASS_MINOR 3
|
||||
#define CUTLASS_PATCH 1
|
||||
#define CUTLASS_PATCH 2
|
||||
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
|
||||
|
||||
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
|
||||
|
@ -237,6 +237,12 @@ struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
|
||||
Coord<4> offset = offset_func(ptr_idx);
|
||||
pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot<int>(stride);
|
||||
}
|
||||
|
||||
if (((threadIdx.x >> 5) * Iterations::kD) & 2) {
|
||||
Scalar *tmp = pointer[0];
|
||||
pointer[0] = pointer[1];
|
||||
pointer[1] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a fragment
|
||||
@ -254,16 +260,12 @@ struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) { // 2x STS operations per LDG
|
||||
|
||||
int warp_id = (threadIdx.x >> 5);
|
||||
|
||||
int ldg_idx = d + warp_id * Iterations::kD;
|
||||
int k_idx = w + h * 8;
|
||||
int smem_row = (d >> 1);
|
||||
|
||||
// Two store pointers
|
||||
int ptr_idx = ((ldg_idx & 1) ^ ((ldg_idx >> 1) & 1));
|
||||
Scalar *_pointer = pointer[(d & 1) ^ ((d >> 1) & 1)];
|
||||
|
||||
Scalar *_pointer = pointer[ptr_idx];
|
||||
Coord<4> sts_offset = make_Coord(k_idx, smem_row, 0, 0);
|
||||
|
||||
Store<typename Fragment::Element, kAccessSize, kMemorySpace>::store(
|
||||
@ -277,6 +279,7 @@ struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
|
||||
|
||||
/// Increments store iterator to next tile
|
||||
__device__ Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
|
||||
pointer[ptr_idx] +=
|
||||
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(stride);
|
||||
@ -293,6 +296,7 @@ struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
|
||||
|
||||
/// Increments store iterator to previous tile
|
||||
__device__ Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
|
||||
pointer[ptr_idx] -=
|
||||
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(stride);
|
||||
|
@ -183,7 +183,7 @@ TEST(Volta884_f16_s884gemm_128x128x32_tt, short_480x280x224) {
|
||||
// Contiguous - s884gemm
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#if 0
|
||||
|
||||
TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x32) {
|
||||
|
||||
typedef cutlass::gemm::Volta884GemmTraits<
|
||||
@ -218,7 +218,6 @@ TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x30_residue) {
|
||||
run_gemm<GemmTraits>(64, 64, 30);
|
||||
}
|
||||
|
||||
#if 0
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x64) {
|
||||
@ -874,7 +873,6 @@ TEST(Volta884_f16_s884gemm_128x128x32_nn, 392x264x192) {
|
||||
|
||||
run_gemm<GemmTraits>(392, 264, 192);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1281,7 +1279,6 @@ TEST(Volta884_f16_s884gemm_f16_128x256x32_tn, 480x280x224) {
|
||||
|
||||
run_gemm<GemmTraits>(480, 280, 224);
|
||||
}
|
||||
#endif
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
|
Loading…
Reference in New Issue
Block a user