Resolved issue for incorrect SGEMM on Maxwell architecture.

This commit is contained in:
akerr 2018-12-19 15:07:16 -08:00
parent ed2ed4d667
commit 822b0952cd
4 changed files with 29 additions and 6 deletions

View File

@ -1,5 +1,8 @@
# NVIDIA CUTLASS Changelog
## [1.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.1) (2018-12-19)
* Resolved issue with sm50 and sm52 architectures
## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
* Parallelized reductions across threadblocks ("Split-K")
* Improved IGEMM performance

View File

@ -33,7 +33,7 @@
#define CUTLASS_MAJOR 1
#define CUTLASS_MINOR 2
#define CUTLASS_PATCH 0
#define CUTLASS_PATCH 1
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
#ifdef __NVCC__

View File

@ -52,7 +52,6 @@ struct FragmentMultiplyAdd {
/// Multiply : d = a*b.
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
d[j] = b[j * kReduction + 0];
@ -61,7 +60,6 @@ struct FragmentMultiplyAdd {
}
d[j] = a * ScalarAlphaBeta(d[j]);
}
#endif
}
/// Multiply : d = a*b + c.
@ -70,7 +68,7 @@ struct FragmentMultiplyAdd {
FragmentB_ const& b,
FragmentCd_ const& c,
FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
d[j] = b[j * kReduction + 0];
@ -79,7 +77,6 @@ struct FragmentMultiplyAdd {
}
d[j] = a * ScalarAlphaBeta(d[j]) + ScalarAlphaBeta(c[j]);
}
#endif
}
};

View File

@ -77,6 +77,8 @@
#include <sstream>
#include <vector>
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__) >= 530
// CUTLASS includes needed for mixed-precision GEMM kernel
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/fp16_sgemm_traits.h"
@ -312,6 +314,24 @@ cudaError_t TestCutlassGemm(int M, int N, int K, cutlass::half_t alpha, cutlass:
//
int main(int argc, const char *arg[]) {
//
// This example uses half-precision and is only suitable for devices with compute capabitliy 5.3 or greater.
//
cudaDeviceProp prop;
cudaError_t result = cudaGetDeviceProperties(&prop, 0);
if (result != cudaSuccess) {
std::cerr << "Failed to query device properties with error " << cudaGetErrorString(result) << std::endl;
return -1;
}
if (!(prop.major > 5 || (prop.major == 5 && prop.minor >= 3))) {
std::cerr << "This example uses mixed precision and is only suitable for devices with compute capability 5.3 or greater.\n";
std::cerr << "You are using a CUDA device with compute capability " << prop.major << "." << prop.minor << std::endl;
return -1;
}
//
// Parse the command line to obtain GEMM dimensions and scalar values.
//
@ -341,7 +361,7 @@ int main(int argc, const char *arg[]) {
// Run the CUTLASS GEMM test.
//
cudaError_t result = TestCutlassGemm(
result = TestCutlassGemm(
problem[0], // GEMM M dimension
problem[1], // GEMM N dimension
problem[2], // GEMM K dimension
@ -358,3 +378,6 @@ int main(int argc, const char *arg[]) {
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#endif