From 822b0952cd2464077c17a2ef848d2bad483affc7 Mon Sep 17 00:00:00 2001
From: akerr
Date: Wed, 19 Dec 2018 15:07:16 -0800
Subject: [PATCH] Resolved issue with incorrect SGEMM on Maxwell architecture.

---
 CHANGELOG.md                                  |  3 +++
 cutlass/cutlass.h                             |  2 +-
 cutlass/fragment_multiply_add.h               |  5 +---
 .../02_cutlass_utilities/cutlass_utilities.cu | 25 ++++++++++++++++++-
 4 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 311c72b0..b956c073 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,8 @@
 # NVIDIA CUTLASS Changelog
 
+## [1.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.1) (2018-12-19)
+ * Resolved issue with sm50 and sm52 architectures
+
 ## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
  * Parallelized reductions across threadblocks ("Split-K")
  * Improved IGEMM performance
diff --git a/cutlass/cutlass.h b/cutlass/cutlass.h
index 2851a5f0..ac5420d7 100644
--- a/cutlass/cutlass.h
+++ b/cutlass/cutlass.h
@@ -33,7 +33,7 @@
 
 #define CUTLASS_MAJOR 1
 #define CUTLASS_MINOR 2
-#define CUTLASS_PATCH 0
+#define CUTLASS_PATCH 1
 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
 
 #ifdef __NVCC__
diff --git a/cutlass/fragment_multiply_add.h b/cutlass/fragment_multiply_add.h
index de2c8052..8bcf8120 100644
--- a/cutlass/fragment_multiply_add.h
+++ b/cutlass/fragment_multiply_add.h
@@ -52,7 +52,6 @@ struct FragmentMultiplyAdd {
   /// Multiply : d = a*b.
   template <typename FragmentB_, typename FragmentCd_>
   CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_& d) {
-#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
     int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
     for (int j = 0; j < FragmentCd_::kElements; ++j) {
       d[j] = b[j * kReduction + 0];
@@ -61,7 +60,6 @@ struct FragmentMultiplyAdd {
       }
       d[j] = a * ScalarAlphaBeta(d[j]);
     }
-#endif
   }
 
   /// Multiply : d = a*b + c.
@@ -70,7 +68,7 @@ struct FragmentMultiplyAdd {
                                    FragmentB_ const& b,
                                    FragmentCd_ const& c,
                                    FragmentCd_& d) {
-#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
+
     int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
     for (int j = 0; j < FragmentCd_::kElements; ++j) {
       d[j] = b[j * kReduction + 0];
@@ -79,7 +77,6 @@ struct FragmentMultiplyAdd {
      }
      d[j] = a * ScalarAlphaBeta(d[j]) + ScalarAlphaBeta(c[j]);
    }
-#endif
  }
 };
 
diff --git a/examples/02_cutlass_utilities/cutlass_utilities.cu b/examples/02_cutlass_utilities/cutlass_utilities.cu
index 7ca79c80..7f04cc57 100644
--- a/examples/02_cutlass_utilities/cutlass_utilities.cu
+++ b/examples/02_cutlass_utilities/cutlass_utilities.cu
@@ -77,6 +77,8 @@
 #include <iostream>
 #include <sstream>
 
+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__) >= 530
+
 // CUTLASS includes needed for mixed-precision GEMM kernel
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/gemm/fp16_sgemm_traits.h"
@@ -312,6 +314,24 @@ cudaError_t TestCutlassGemm(int M, int N, int K, cutlass::half_t alpha, cutlass:
 
 //
 int main(int argc, const char *arg[]) {
+  //
+  // This example uses half-precision and is only suitable for devices with compute capability 5.3 or greater.
+  //
+
+  cudaDeviceProp prop;
+  cudaError_t result = cudaGetDeviceProperties(&prop, 0);
+
+  if (result != cudaSuccess) {
+    std::cerr << "Failed to query device properties with error " << cudaGetErrorString(result) << std::endl;
+    return -1;
+  }
+
+  if (!(prop.major > 5 || (prop.major == 5 && prop.minor >= 3))) {
+    std::cerr << "This example uses mixed precision and is only suitable for devices with compute capability 5.3 or greater.\n";
+    std::cerr << "You are using a CUDA device with compute capability " << prop.major << "." << prop.minor << std::endl;
+    return -1;
+  }
+
   //
   // Parse the command line to obtain GEMM dimensions and scalar values.
   //
@@ -341,7 +361,7 @@ int main(int argc, const char *arg[]) {
   // Run the CUTLASS GEMM test.
   //
 
-  cudaError_t result = TestCutlassGemm(
+  result = TestCutlassGemm(
     problem[0],     // GEMM M dimension
     problem[1],     // GEMM N dimension
     problem[2],     // GEMM K dimension
@@ -358,3 +378,6 @@ int main(int argc, const char *arg[]) {
 }
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
+
+#endif
+
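
Note (not part of the patch): below is a minimal sketch of how a downstream project could confirm it has picked up this fix. It is hypothetical and assumes the CUTLASS headers are on the include path and that device 0 is the device of interest; it relies only on the CUTLASS_VERSION macro bumped above (major*100 + minor*10 + patch, so 121 for release 1.2.1) and the same cudaGetDeviceProperties() query the example now performs at run time.

// check_cutlass_121.cu (hypothetical file name) -- not part of this patch.
// Build sketch: nvcc -std=c++11 -I<path-to-cutlass> check_cutlass_121.cu
#include <iostream>
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

// Fail the build if the 1.2.1 fix is not present.
static_assert(CUTLASS_VERSION >= 121,
              "CUTLASS 1.2.1 or newer is required for correct SGEMM on sm50/sm52.");

int main() {
  cudaDeviceProp prop;
  cudaError_t result = cudaGetDeviceProperties(&prop, 0);
  if (result != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() failed: " << cudaGetErrorString(result) << std::endl;
    return -1;
  }

  // The half-precision example requires compute capability 5.3 or greater
  // (the same run-time check added to examples/02_cutlass_utilities above).
  bool fp16_ok = (prop.major > 5) || (prop.major == 5 && prop.minor >= 3);

  std::cout << "CUTLASS " << CUTLASS_MAJOR << "." << CUTLASS_MINOR << "." << CUTLASS_PATCH
            << " on sm" << prop.major << prop.minor
            << (fp16_ok ? ": half-precision example supported"
                        : ": half-precision example not supported")
            << std::endl;
  return 0;
}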