From e18292db46a1b2601ad4af7d33a075c9918f33e4 Mon Sep 17 00:00:00 2001 From: Artem Belevich Date: Thu, 2 May 2019 10:40:05 -0700 Subject: [PATCH 1/2] Make CUTLASS compileable with Clang. Requires a recent clang build (r359248 or newer). Enable compilation with clang with these options: cmake -DCUDA_COMPILER=clang -DCMAKE_CXX_COMPILER=/path/to/clang++ --- CMakeLists.txt | 69 ++++++++++++++++----- cutlass/coord.h | 4 +- cutlass/cutlass.h | 4 +- cutlass/gemm/gemm_mainloop.h | 2 +- cutlass/gemm/igemm_multiply_add.h | 2 + cutlass/gemm/linear_scaling.h | 3 +- cutlass/gemm/threadblock_swizzle.h | 2 +- cutlass/reduction/batched_reduction.h | 2 +- cutlass/reduction/threadblock_swizzle.h | 2 +- cutlass/wmma_matrix.h | 8 +-- tools/test/unit/core/layout_verification.cu | 5 ++ tools/test/unit/core/layout_verification.h | 2 + tools/test/unit/core/tile_iterator.cu | 4 +- tools/util/half.h | 20 ++++-- tools/util/reference/detail/inner_product.h | 7 ++- tools/util/reference/device/thread/gemm.h | 2 +- tools/util/type_traits.h | 2 +- 17 files changed, 102 insertions(+), 38 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 25a967b8..b89a0e73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,8 +24,9 @@ cmake_minimum_required(VERSION 3.3.0 FATAL_ERROR) set(CUTLASS_LANGUAGES CXX) +if( CUDA_COMPILER STREQUAL "clang" ) # CMake 3.9.0 has native support for CUDA without the need of the CUDA package. Use it! -if(WIN32 AND NOT ${CMAKE_VERSION} VERSION_LESS "3.9.0") +elseif(WIN32 AND NOT ${CMAKE_VERSION} VERSION_LESS "3.9.0") list(APPEND CUTLASS_LANGUAGES CUDA) set(CUTLASS_NATIVE_CUDA TRUE) @@ -48,6 +49,32 @@ endif() project(CUTLASS ${CUTLASS_LANGUAGES}) +if( CUDA_COMPILER STREQUAL "clang" ) + if( NOT CMAKE_CXX_COMPILER_ID STREQUAL "Clang" ) + message(FATAL_ERROR "C++ compiler must be Clang. 
Currently it's ${CMAKE_CXX_COMPILER_ID}" ) + endif() + string(APPEND CLANG_FLAGS " --std=c++11") + string(APPEND CLANG_FLAGS " --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}") + string(APPEND CLANG_FLAGS " -mllvm -pragma-unroll-threshold=100000") + string(APPEND CLANG_FLAGS " -mllvm -unroll-threshold=5000") + string(APPEND CLANG_FLAGS " -Wno-unused-command-line-argument") + # needed for libcublasLt.so in case it's installed in the same location as libcudart.so + # dynamic linker can find it if linker sets RPATH (forced by --disable-new-dtags) + # Otherwise linker uses RUNPATH and that does not propagate to loaded libs. + string(APPEND CLANG_FLAGS " -Wl,--disable-new-dtags") + + link_libraries(${CUDA_CUDART_LIBRARY}) + # Treat CUDA files as C++ files + macro(cutlass_add_executable) + foreach(File ${ARGN}) + if(${File} MATCHES ".*\.cu$") + set_source_files_properties(${File} PROPERTIES LANGUAGE CXX) + endif() + endforeach() + add_executable(${ARGN}) + endmacro() +endif() + # check if the configuration is supported if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 ) message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!") @@ -167,6 +194,7 @@ endif() # Set NVCC arguments foreach(ARCH ${CUTLASS_NVCC_ARCHS}) + string(APPEND CLANG_FLAGS " --cuda-gpu-arch=sm_${ARCH}") if(CUTLASS_NVCC_EMBED_CUBIN) string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}") endif() @@ -175,12 +203,16 @@ foreach(ARCH ${CUTLASS_NVCC_ARCHS}) endif() endforeach() +if(CUTLASS_NVCC_EMBED_PTX) + string(APPEND CLANG_FLAGS " --cuda-include-ptx=all") +endif() + if (CUTLASS_ENABLE_TENSOR_CORE_MMA) - string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1") + string(APPEND COMMON_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1") endif() if (CUTLASS_ENABLE_CUBLAS) - string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_CUBLAS=1") + string(APPEND COMMON_FLAGS " -DCUTLASS_ENABLE_CUBLAS=1") endif() if (CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST) @@ -189,6 +221,7 @@ endif() if (CUTLASS_NVCC_KEEP) string(APPEND NVCC_FLAGS " -keep") 
+ string(APPEND CLANG_FLAGS " -save-temps=obj") endif() if (WIN32 AND CUTLASS_NATIVE_CUDA) @@ -196,28 +229,34 @@ if (WIN32 AND CUTLASS_NATIVE_CUDA) else() string(APPEND NVCC_FLAGS " -lineinfo") endif() +string(APPEND CLANG_FLAGS " -gmlt") if (UNIX) string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion") endif() -string(APPEND NVCC_FLAGS_DEBUG " -g") -string(APPEND NVCC_FLAGS_RELWITHDEBINFO " -O3") -string(APPEND NVCC_FLAGS_RELEASE " -O3") +string(APPEND COMMON_FLAGS_DEBUG " -g") +string(APPEND COMMON_FLAGS_RELWITHDEBINFO " -O3") +string(APPEND COMMON_FLAGS_RELEASE " -O3") # define NDEBUG for release mode to disable assertions -string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG") +string(APPEND COMMON_FLAGS_RELEASE " -DNDEBUG") -if (CUTLASS_NATIVE_CUDA) - set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}") - set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}") - set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${NVCC_FLAGS_RELWITHDEBINFO}") - set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}") +if( CUDA_COMPILER STREQUAL "clang" ) + set(CMAKE_CXX_FLAGS "${COMMON_FLAGS} ${CLANG_FLAGS}") + set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE}") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO}") + set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG}") +elseif (CUTLASS_NATIVE_CUDA) + set(CMAKE_CUDA_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS}") + set(CMAKE_CUDA_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE}") + set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO}") + set(CMAKE_CUDA_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG}") else() - set(CUDA_NVCC_FLAGS ${NVCC_FLAGS}) - set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG}) - set(CUDA_NVCC_FLAGS_RELWITHDEBINFO ${NVCC_FLAGS_RELWITHDEBINFO}) - set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE}) + set(CUDA_NVCC_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS}") + set(CUDA_NVCC_FLAGS_DEBUG ${COMMON_FLAGS_DEBUG}) + set(CUDA_NVCC_FLAGS_RELWITHDEBINFO ${COMMON_FLAGS_RELWITHDEBINFO}) + set(CUDA_NVCC_FLAGS_RELEASE ${COMMON_FLAGS_RELEASE}) endif() # diff --git a/cutlass/coord.h b/cutlass/coord.h index 7e91d6e9..9e411ab7 100644 
--- a/cutlass/coord.h +++ b/cutlass/coord.h @@ -103,10 +103,10 @@ struct Coord { Coord result; for (int i = 0; i < Slice; ++i) { if (i + start < kRank) { - slice[i] = idx[i + start]; + result[i] = idx[i + start]; } else { - slice[i] = identity; + result[i] = identity; } } return result; diff --git a/cutlass/cutlass.h b/cutlass/cutlass.h index 783ea3b6..a44950cb 100644 --- a/cutlass/cutlass.h +++ b/cutlass/cutlass.h @@ -37,7 +37,7 @@ #define CUTLASS_PATCH 1 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) -#ifdef __NVCC__ +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) #define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ #define CUTLASS_DEVICE __forceinline__ __device__ #elif defined(__CUDACC_RTC__) @@ -61,7 +61,7 @@ #ifdef __NVCC__ #define CUTLASS_PRAGMA_UNROLL #pragma unroll #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 -#elif defined(__CUDACC_RTC__) +#elif defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") #endif diff --git a/cutlass/gemm/gemm_mainloop.h b/cutlass/gemm/gemm_mainloop.h index a65cb3ae..9857da57 100644 --- a/cutlass/gemm/gemm_mainloop.h +++ b/cutlass/gemm/gemm_mainloop.h @@ -167,7 +167,7 @@ struct GemmMainloop { // Swizzle the IDs of the block (to enable better cache behavior). typename Traits::BlockSwizzle block_swizzle; Coord<3> threadblock_offset = - block_swizzle.get_threadblock_offset(make_Coord_from_shape()); + block_swizzle.get_threadblock_offset(make_Coord_from_shape()); // We may want to use shared memory to clear the registers. 
typedef typename Traits::ClearAccumulators ClearAccumulators; diff --git a/cutlass/gemm/igemm_multiply_add.h b/cutlass/gemm/igemm_multiply_add.h index 7892850d..b6e12cc5 100644 --- a/cutlass/gemm/igemm_multiply_add.h +++ b/cutlass/gemm/igemm_multiply_add.h @@ -82,7 +82,9 @@ struct ThreadMultiplyAdd int const* a_int = reinterpret_cast(&a[0]); int const* b_int = reinterpret_cast(&b[0]); +#pragma unroll for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { +#pragma unroll for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" diff --git a/cutlass/gemm/linear_scaling.h b/cutlass/gemm/linear_scaling.h index e747b218..4e90241d 100644 --- a/cutlass/gemm/linear_scaling.h +++ b/cutlass/gemm/linear_scaling.h @@ -66,7 +66,8 @@ struct LinearScaling { // Constructor CUTLASS_HOST_DEVICE - Params(Scalar _alpha = 0, Scalar _beta = 0) : alpha(_alpha), beta(_beta) {} + Params(Scalar _alpha = 0.0f, Scalar _beta = 0.0f) + : alpha(_alpha), beta(_beta) {} /// Initialize the parameters CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta) { diff --git a/cutlass/gemm/threadblock_swizzle.h b/cutlass/gemm/threadblock_swizzle.h index 737b89a9..3733d3a0 100644 --- a/cutlass/gemm/threadblock_swizzle.h +++ b/cutlass/gemm/threadblock_swizzle.h @@ -67,7 +67,7 @@ struct IdentityBlockSwizzle { CUTLASS_HOST_DEVICE IdentityBlockSwizzle() {} /// Swizzle the block index. 
- CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } + CUTLASS_DEVICE dim3 swizzle() { return dim3(blockIdx.x, blockIdx.y, blockIdx.z); } /// CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, diff --git a/cutlass/reduction/batched_reduction.h b/cutlass/reduction/batched_reduction.h index 83324ec0..6915a9a4 100644 --- a/cutlass/reduction/batched_reduction.h +++ b/cutlass/reduction/batched_reduction.h @@ -70,7 +70,7 @@ struct BatchedReduction { // Swizzle the IDs of the block typename Traits::BlockSwizzle block_swizzle; Coord<3> threadblock_offset = - block_swizzle.get_threadblock_offset(make_Coord_from_shape()); + block_swizzle.get_threadblock_offset(make_Coord_from_shape()); int subTileSize = gridDim.x * Traits::SubTile::kW; int tileSize = params.problem_size[1] * params.problem_size[2]; diff --git a/cutlass/reduction/threadblock_swizzle.h b/cutlass/reduction/threadblock_swizzle.h index 6e42cada..3b825a9b 100644 --- a/cutlass/reduction/threadblock_swizzle.h +++ b/cutlass/reduction/threadblock_swizzle.h @@ -35,7 +35,7 @@ struct DefaultBlockSwizzle { CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {} /// Swizzle the block index. 
- CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } + CUTLASS_DEVICE dim3 swizzle() { return {blockIdx.x, blockIdx.y, blockIdx.z}; } /// CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size, diff --git a/cutlass/wmma_matrix.h b/cutlass/wmma_matrix.h index 647acd80..93857800 100644 --- a/cutlass/wmma_matrix.h +++ b/cutlass/wmma_matrix.h @@ -30,20 +30,20 @@ #if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700) #define CUTLASS_USE_WMMA_API -#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 720) +#if defined(__CUDACC__) && (CUDA_VERSION >= 10000) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 720) #define CUTLASS_USE_INT_WMMA #endif -#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750) +#if defined(__CUDACC__) && (CUDA_VERSION >= 10000) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750) #define CUTLASS_USE_SUBBYTE_WMMA #endif #include "stdio.h" -#if __CUDACC_VER_MAJOR__ >= 10 +#if CUDA_VERSION >= 10000 #include #else -#include +#include #endif #include "cutlass/fragment.h" #include "cutlass/matrix_traits.h" diff --git a/tools/test/unit/core/layout_verification.cu b/tools/test/unit/core/layout_verification.cu index 76c1d7c6..2f007c2e 100644 --- a/tools/test/unit/core/layout_verification.cu +++ b/tools/test/unit/core/layout_verification.cu @@ -142,12 +142,17 @@ int Layout::operator()(Layout::Coordinate const &_coord) const { } +// test::Layout::Coordinate is actually a std::vector<>, so for ADL lookup to +// work, the operator<< must be in std::. GCC does look it up in global +// namespace, but that's a bug. +namespace std { std::ostream & operator<<(std::ostream &out, test::Layout::Coordinate const &coord) { for (int i = 0; i < coord.size(); ++i) { out << (i ? 
", " : "") << coord.at(i); } return out; } +} // namespace std /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/core/layout_verification.h b/tools/test/unit/core/layout_verification.h index 86222aca..5d4541de 100644 --- a/tools/test/unit/core/layout_verification.h +++ b/tools/test/unit/core/layout_verification.h @@ -94,7 +94,9 @@ class Layout { } /// Implemented in layout_verification.cu +namespace std { std::ostream& operator<<(std::ostream& out, test::Layout::Coordinate const& coord); +} namespace test { diff --git a/tools/test/unit/core/tile_iterator.cu b/tools/test/unit/core/tile_iterator.cu index eabc2349..0a1d86a7 100644 --- a/tools/test/unit/core/tile_iterator.cu +++ b/tools/test/unit/core/tile_iterator.cu @@ -42,9 +42,7 @@ __global__ void load_store_global( typename cutlass::TileStoreIterator::Scalar *output, int kW, - int kH, - typename cutlass::TileStoreIterator::Scalar identity = 0 + int kH ) { /// Load iterator diff --git a/tools/util/half.h b/tools/util/half.h index 4d5b5743..bd4ad06f 100644 --- a/tools/util/half.h +++ b/tools/util/half.h @@ -86,7 +86,7 @@ class half_t { half_t operator+(half_t const&) const; half_t operator-() const; half_t operator-(half_t const&) const; - half_t operator*(half_t const&)const; + half_t operator*(half_t const&) const; half_t operator/(half_t const&) const; half_t& operator+=(half_t const&); @@ -107,6 +107,12 @@ class half_t { uint16_t& raw() { return x; } uint16_t raw() const { return x; } +#if defined(__clang__) + __device__ half_t operator+(half_t const&) const; + __device__ half_t operator*(half_t const&) const; + __device__ operator float() const; /// conversion to fp32 +#endif + // // Stream interactions // @@ -209,7 +215,7 @@ std::string lexical_cast(cutlass::half_t const& arg); #define HLF_MANT_DIG 10 -namespace std { +namespace cutlass { cutlass::half_t abs(cutlass::half_t const&); /// absolute value @@ -229,7 +235,10 @@ int 
fpclassify(cutlass::half_t const&); /// returns a flag classifying floating bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t +cutlass::half_t copysign(cutlass::half_t const&, cutlass::half_t const&); +} +namespace std { /// Numeric limits template <> struct numeric_limits { @@ -696,8 +705,7 @@ std::string lexical_cast(cutlass::half_t const& arg) { // Standard Library Operations // -// std -namespace std { +namespace cutlass { inline cutlass::half_t abs(cutlass::half_t const& h) { return cutlass::half_t::bitcast(h.x & 0x7fff); @@ -737,4 +745,8 @@ inline bool signbit(cutlass::half_t const& h) { return h.signbit(); } inline cutlass::half_t sqrt(cutlass::half_t const& h) { return cutlass::half_t(std::sqrt(float(h))); } +inline cutlass::half_t copysign(cutlass::half_t const& a, + cutlass::half_t const& b) { + return cutlass::half_t(std::copysign(float(a), float(b))); +} } // namespace std diff --git a/tools/util/reference/detail/inner_product.h b/tools/util/reference/detail/inner_product.h index 26ebe1e8..7a19b6d1 100644 --- a/tools/util/reference/detail/inner_product.h +++ b/tools/util/reference/detail/inner_product.h @@ -45,6 +45,12 @@ Ctype inner_product(Atype a, Btype b, Ctype c) { return Ctype(a) * Ctype(b) + c; } +#if defined(__clang__) && defined(__CUDA__) +__device__ __forceinline__ __half inner_product(__half a, __half b, __half c) { + return a * b + c; +} +#endif + /// Specialization for matrix multiplication with binary operands template <> CUTLASS_HOST_DEVICE @@ -124,4 +130,3 @@ struct Cast { } // namespace detail } // namespace reference } // namespace cutlass - diff --git a/tools/util/reference/device/thread/gemm.h b/tools/util/reference/device/thread/gemm.h index 05e8262c..c03003ca 100644 --- a/tools/util/reference/device/thread/gemm.h +++ b/tools/util/reference/device/thread/gemm.h @@ -142,7 +142,7 @@ struct Gemm { } /// Performs linear scaling of 
matrix product and updates output tensor - CUTLASS_HOST_DEVICE + __device__ Gemm & epilogue( gemm::GemmCoord problem_size, ScalarType alpha, diff --git a/tools/util/type_traits.h b/tools/util/type_traits.h index 7a82e52b..7915264c 100644 --- a/tools/util/type_traits.h +++ b/tools/util/type_traits.h @@ -140,7 +140,7 @@ struct TypeTraits { typedef int16_t integer_type; typedef uint16_t unsigned_type; static inline half remove_negative_zero(half x) { - integer_type h_int = reinterpret_cast(x); + unsigned_type h_int = reinterpret_cast(x); if (h_int == 0x8000) { h_int = 0; } From fb8b3a98b77f6aae13901a8f5dad73903545dcfd Mon Sep 17 00:00:00 2001 From: Artem Belevich Date: Thu, 9 May 2019 14:53:01 -0700 Subject: [PATCH 2/2] Addressed code review comments. --- cutlass/gemm/igemm_multiply_add.h | 4 ++-- cutlass/reduction/threadblock_swizzle.h | 2 +- cutlass/wmma_matrix.h | 4 ---- tools/util/half.h | 5 +++++ 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cutlass/gemm/igemm_multiply_add.h b/cutlass/gemm/igemm_multiply_add.h index b6e12cc5..e90b87e0 100644 --- a/cutlass/gemm/igemm_multiply_add.h +++ b/cutlass/gemm/igemm_multiply_add.h @@ -82,9 +82,9 @@ struct ThreadMultiplyAdd int const* a_int = reinterpret_cast(&a[0]); int const* b_int = reinterpret_cast(&b[0]); -#pragma unroll + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { -#pragma unroll + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" diff --git a/cutlass/reduction/threadblock_swizzle.h b/cutlass/reduction/threadblock_swizzle.h index 3b825a9b..f2b796eb 100644 --- a/cutlass/reduction/threadblock_swizzle.h +++ b/cutlass/reduction/threadblock_swizzle.h @@ -35,7 +35,7 @@ struct DefaultBlockSwizzle { CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {} /// Swizzle the block index. 
- CUTLASS_DEVICE dim3 swizzle() { return {blockIdx.x, blockIdx.y, blockIdx.z}; } + CUTLASS_DEVICE dim3 swizzle() { return dim3(blockIdx.x, blockIdx.y, blockIdx.z); } /// CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size, diff --git a/cutlass/wmma_matrix.h b/cutlass/wmma_matrix.h index 93857800..bad21149 100644 --- a/cutlass/wmma_matrix.h +++ b/cutlass/wmma_matrix.h @@ -40,11 +40,7 @@ #include "stdio.h" -#if CUDA_VERSION >= 10000 #include -#else -#include -#endif #include "cutlass/fragment.h" #include "cutlass/matrix_traits.h" #include "cutlass/shape.h" diff --git a/tools/util/half.h b/tools/util/half.h index bd4ad06f..48940a8f 100644 --- a/tools/util/half.h +++ b/tools/util/half.h @@ -235,7 +235,9 @@ int fpclassify(cutlass::half_t const&); /// returns a flag classifying floating bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t +#if __cplusplus >= 201103L cutlass::half_t copysign(cutlass::half_t const&, cutlass::half_t const&); +#endif } namespace std { @@ -745,8 +747,11 @@ inline bool signbit(cutlass::half_t const& h) { return h.signbit(); } inline cutlass::half_t sqrt(cutlass::half_t const& h) { return cutlass::half_t(std::sqrt(float(h))); } + +#if __cplusplus >= 201103L inline cutlass::half_t copysign(cutlass::half_t const& a, cutlass::half_t const& b) { return cutlass::half_t(std::copysign(float(a), float(b))); } +#endif } // namespace std