From fb8b3a98b77f6aae13901a8f5dad73903545dcfd Mon Sep 17 00:00:00 2001 From: Artem Belevich Date: Thu, 9 May 2019 14:53:01 -0700 Subject: [PATCH] Addressed code review comments. --- cutlass/gemm/igemm_multiply_add.h | 4 ++-- cutlass/reduction/threadblock_swizzle.h | 2 +- cutlass/wmma_matrix.h | 4 ---- tools/util/half.h | 5 +++++ 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cutlass/gemm/igemm_multiply_add.h b/cutlass/gemm/igemm_multiply_add.h index b6e12cc5..e90b87e0 100644 --- a/cutlass/gemm/igemm_multiply_add.h +++ b/cutlass/gemm/igemm_multiply_add.h @@ -82,9 +82,9 @@ struct ThreadMultiplyAdd int const* a_int = reinterpret_cast<int const*>(&a[0]); int const* b_int = reinterpret_cast<int const*>(&b[0]); -#pragma unroll + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { -#pragma unroll + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" diff --git a/cutlass/reduction/threadblock_swizzle.h b/cutlass/reduction/threadblock_swizzle.h index 3b825a9b..f2b796eb 100644 --- a/cutlass/reduction/threadblock_swizzle.h +++ b/cutlass/reduction/threadblock_swizzle.h @@ -35,7 +35,7 @@ struct DefaultBlockSwizzle { CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {} /// Swizzle the block index. 
- CUTLASS_DEVICE dim3 swizzle() { return {blockIdx.x, blockIdx.y, blockIdx.z}; } + CUTLASS_DEVICE dim3 swizzle() { return dim3(blockIdx.x, blockIdx.y, blockIdx.z); } /// CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size, diff --git a/cutlass/wmma_matrix.h b/cutlass/wmma_matrix.h index 93857800..bad21149 100644 --- a/cutlass/wmma_matrix.h +++ b/cutlass/wmma_matrix.h @@ -40,11 +40,7 @@ #include "stdio.h" -#if CUDA_VERSION >= 10000 #include <mma.h> -#else -#include <crt/mma.h> -#endif #include "cutlass/fragment.h" #include "cutlass/matrix_traits.h" #include "cutlass/shape.h" diff --git a/tools/util/half.h b/tools/util/half.h index bd4ad06f..48940a8f 100644 --- a/tools/util/half.h +++ b/tools/util/half.h @@ -235,7 +235,9 @@ int fpclassify(cutlass::half_t const&); /// returns a flag classifying floating bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t +#if __cplusplus >= 201103L cutlass::half_t copysign(cutlass::half_t const&, cutlass::half_t const&); +#endif } namespace std { @@ -745,8 +747,11 @@ inline bool signbit(cutlass::half_t const& h) { return h.signbit(); } inline cutlass::half_t sqrt(cutlass::half_t const& h) { return cutlass::half_t(std::sqrt(float(h))); } + +#if __cplusplus >= 201103L inline cutlass::half_t copysign(cutlass::half_t const& a, cutlass::half_t const& b) { return cutlass::half_t(std::copysign(float(a), float(b))); } +#endif } // namespace std