Make CUTLASS compilable with Clang.

Requires a recent clang build (r359248 or newer).

Enable compilation with clang using these options:
cmake -DCUDA_COMPILER=clang -DCMAKE_CXX_COMPILER=/path/to/clang++
Artem Belevich 2019-05-02 10:40:05 -07:00
parent fe3438a3c1
commit e18292db46
17 changed files with 102 additions and 38 deletions

View File

@@ -24,8 +24,9 @@ cmake_minimum_required(VERSION 3.3.0 FATAL_ERROR)
 set(CUTLASS_LANGUAGES CXX)
+if( CUDA_COMPILER STREQUAL "clang" )
 # CMake 3.9.0 has native support for CUDA without the need of the CUDA package. Use it!
-if(WIN32 AND NOT ${CMAKE_VERSION} VERSION_LESS "3.9.0")
+elseif(WIN32 AND NOT ${CMAKE_VERSION} VERSION_LESS "3.9.0")
   list(APPEND CUTLASS_LANGUAGES CUDA)
   set(CUTLASS_NATIVE_CUDA TRUE)
@@ -48,6 +49,32 @@ endif()
 project(CUTLASS ${CUTLASS_LANGUAGES})
+if( CUDA_COMPILER STREQUAL "clang" )
+  if( NOT CMAKE_CXX_COMPILER_ID STREQUAL "Clang" )
+    message(FATAL_ERROR "C++ compiler must be Clang. Currently it's ${CMAKE_CXX_COMPILER_ID}" )
+  endif()
+
+  string(APPEND CLANG_FLAGS " --std=c++11")
+  string(APPEND CLANG_FLAGS " --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}")
+  string(APPEND CLANG_FLAGS " -mllvm -pragma-unroll-threshold=100000")
+  string(APPEND CLANG_FLAGS " -mllvm -unroll-threshold=5000")
+  string(APPEND CLANG_FLAGS " -Wno-unused-command-line-argument")
+  # needed for libcublasLt.so in case it's installed in the same location as libcudart.so
+  # dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags)
+  # Otherwise linker uses RUNPATH and that does not propagate to loaded libs.
+  string(APPEND CLANG_FLAGS " -Wl,--disable-new-dtags")
+  link_libraries(${CUDA_CUDART_LIBRARY})
+
+  # Treat CUDA files as C++ files
+  macro(cutlass_add_executable)
+    foreach(File ${ARGN})
+      if(${File} MATCHES ".*\.cu$")
+        set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
+      endif()
+    endforeach()
+    add_executable(${ARGN})
+  endmacro()
+endif()
 # check if the configuration is supported
 if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
   message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
@@ -167,6 +194,7 @@ endif()
 # Set NVCC arguments
 foreach(ARCH ${CUTLASS_NVCC_ARCHS})
+  string(APPEND CLANG_FLAGS " --cuda-gpu-arch=sm_${ARCH}")
   if(CUTLASS_NVCC_EMBED_CUBIN)
     string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
   endif()
@@ -175,12 +203,16 @@ foreach(ARCH ${CUTLASS_NVCC_ARCHS})
   endif()
 endforeach()
+if(CUTLASS_NVCC_EMBED_PTX)
+  string(APPEND CLANG_FLAGS " --cuda-include-ptx=all")
+endif()
 if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
-  string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1")
+  string(APPEND COMMON_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1")
 endif()
 if (CUTLASS_ENABLE_CUBLAS)
-  string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_CUBLAS=1")
+  string(APPEND COMMON_FLAGS " -DCUTLASS_ENABLE_CUBLAS=1")
 endif()
 if (CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST)
@@ -189,6 +221,7 @@ endif()
 if (CUTLASS_NVCC_KEEP)
   string(APPEND NVCC_FLAGS " -keep")
+  string(APPEND CLANG_FLAGS " -save-temps=obj")
 endif()
 if (WIN32 AND CUTLASS_NATIVE_CUDA)
@@ -196,28 +229,34 @@ if (WIN32 AND CUTLASS_NATIVE_CUDA)
 else()
   string(APPEND NVCC_FLAGS " -lineinfo")
 endif()
+string(APPEND CLANG_FLAGS " -gmlt")
 if (UNIX)
   string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
 endif()
-string(APPEND NVCC_FLAGS_DEBUG " -g")
-string(APPEND NVCC_FLAGS_RELWITHDEBINFO " -O3")
-string(APPEND NVCC_FLAGS_RELEASE " -O3")
+string(APPEND COMMON_FLAGS_DEBUG " -g")
+string(APPEND COMMON_FLAGS_RELWITHDEBINFO " -O3")
+string(APPEND COMMON_FLAGS_RELEASE " -O3")
 # define NDEBUG for release mode to disable assertions
 string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG")
-if (CUTLASS_NATIVE_CUDA)
-  set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}")
-  set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}")
-  set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${NVCC_FLAGS_RELWITHDEBINFO}")
-  set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
+if( CUDA_COMPILER STREQUAL "clang" )
+  set(CMAKE_CXX_FLAGS "${COMMON_FLAGS} ${CLANG_FLAGS}")
+  set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE}")
+  set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO}")
+  set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG}")
+elseif (CUTLASS_NATIVE_CUDA)
+  set(CMAKE_CUDA_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS}")
+  set(CMAKE_CUDA_FLAGS_RELEASE "${COMMON_FLAGS_RELEASE}")
+  set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS_RELWITHDEBINFO}")
+  set(CMAKE_CUDA_FLAGS_DEBUG "${COMMON_FLAGS_DEBUG}")
 else()
-  set(CUDA_NVCC_FLAGS ${NVCC_FLAGS})
-  set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG})
-  set(CUDA_NVCC_FLAGS_RELWITHDEBINFO ${NVCC_FLAGS_RELWITHDEBINFO})
-  set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE})
+  set(CUDA_NVCC_FLAGS "${COMMON_FLAGS} ${NVCC_FLAGS}")
+  set(CUDA_NVCC_FLAGS_DEBUG ${COMMON_FLAGS_DEBUG})
+  set(CUDA_NVCC_FLAGS_RELWITHDEBINFO ${COMMON_FLAGS_RELWITHDEBINFO})
+  set(CUDA_NVCC_FLAGS_RELEASE ${COMMON_FLAGS_RELEASE})
 endif()
 #

View File

@@ -103,10 +103,10 @@ struct Coord {
     Coord<Slice> result;
     for (int i = 0; i < Slice; ++i) {
       if (i + start < kRank) {
-        slice[i] = idx[i + start];
+        result[i] = idx[i + start];
       }
       else {
-        slice[i] = identity;
+        result[i] = identity;
       }
     }
     return result;

View File

@@ -37,7 +37,7 @@
 #define CUTLASS_PATCH 1
 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
-#ifdef __NVCC__
+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
 #define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
 #define CUTLASS_DEVICE __forceinline__ __device__
 #elif defined(__CUDACC_RTC__)
@@ -61,7 +61,7 @@
 #ifdef __NVCC__
 #define CUTLASS_PRAGMA_UNROLL #pragma unroll
 #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1
-#elif defined(__CUDACC_RTC__)
+#elif defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
 #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
 #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
 #endif
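The _Pragma branch is what Clang now takes: a macro expansion cannot produce a #pragma directive, but the C++11 _Pragma operator can emit one from inside a macro. A minimal sketch of the pattern (macro and function names are illustrative, not CUTLASS's):

#define MY_UNROLL _Pragma("unroll")

__device__ void scale8(float* v, float s) {
  MY_UNROLL
  for (int i = 0; i < 8; ++i) {  // fully unrolled by the pragma above
    v[i] *= s;
  }
}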

View File

@@ -167,7 +167,7 @@ struct GemmMainloop {
   // Swizzle the IDs of the block (to enable better cache behavior).
   typename Traits::BlockSwizzle block_swizzle;
   Coord<3> threadblock_offset =
-      block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());
+      block_swizzle.get_threadblock_offset(make_Coord_from_shape<typename Traits::OutputTile>());
   // We may want to use shared memory to clear the registers.
   typedef typename Traits::ClearAccumulators ClearAccumulators;
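The added typename is required because Traits::OutputTile is a dependent name used as a type inside a template; nvcc accepts the omission, while Clang follows the standard and rejects it. A minimal reproduction of the rule (names are hypothetical, not CUTLASS's):

template <typename Shape> struct FromShape {};

template <typename Traits>
__device__ void use_output_tile() {
  // Without 'typename', a conforming compiler assumes OutputTile names a value,
  // and the template argument list fails to parse as a type.
  FromShape<typename Traits::OutputTile> tile;
  (void)tile;
}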

View File

@@ -82,7 +82,9 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int>
     int const* a_int = reinterpret_cast<int const*>(&a[0]);
     int const* b_int = reinterpret_cast<int const*>(&b[0]);
+#pragma unroll
     for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
+#pragma unroll
       for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
         asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
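The explicit #pragma unroll directives appear to matter mainly for Clang, which unrolls these loops less aggressively than nvcc by default (hence also the -mllvm -pragma-unroll-threshold / -unroll-threshold options added to CLANG_FLAGS above); full unrolling keeps every fragment index a compile-time constant so the arrays stay in registers. A stripped-down sketch of the same pattern, with illustrative sizes and operand wiring rather than CUTLASS's actual fragment layout (dp4a requires sm_61 or newer):

__device__ void dp4a_tile(const int (&a)[4], const int (&b)[4], int (&d)[4][4]) {
#pragma unroll
  for (int j = 0; j < 4; ++j) {
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      // d[j][i] += dot product of the four packed int8 lanes of a[i] and b[j]
      asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
                   : "=r"(d[j][i])
                   : "r"(a[i]), "r"(b[j]), "r"(d[j][i]));
    }
  }
}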

View File

@@ -66,7 +66,8 @@ struct LinearScaling {
   // Constructor
   CUTLASS_HOST_DEVICE
-  Params(Scalar _alpha = 0, Scalar _beta = 0) : alpha(_alpha), beta(_beta) {}
+  Params(Scalar _alpha = 0.0f, Scalar _beta = 0.0f)
+      : alpha(_alpha), beta(_beta) {}
   /// Initialize the parameters
   CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta) {

View File

@@ -67,7 +67,7 @@ struct IdentityBlockSwizzle {
   CUTLASS_HOST_DEVICE IdentityBlockSwizzle() {}
   /// Swizzle the block index.
-  CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
+  CUTLASS_DEVICE dim3 swizzle() { return dim3(blockIdx.x, blockIdx.y, blockIdx.z); }
   ///
   CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
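Under Clang, blockIdx is an object of a compiler-internal builtin class that only exposes the .x/.y/.z accessors; at least in the Clang versions this commit targets it does not convert implicitly to dim3 the way nvcc allows, so the components are copied by hand. The same workaround in isolation:

__device__ dim3 current_block_index() {
  // Portable across nvcc and clang: rebuild the dim3 from its three components.
  return dim3(blockIdx.x, blockIdx.y, blockIdx.z);
}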

View File

@@ -70,7 +70,7 @@ struct BatchedReduction {
   // Swizzle the IDs of the block
   typename Traits::BlockSwizzle block_swizzle;
   Coord<3> threadblock_offset =
-      block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::SubTile>());
+      block_swizzle.get_threadblock_offset(make_Coord_from_shape<typename Traits::SubTile>());
   int subTileSize = gridDim.x * Traits::SubTile::kW;
   int tileSize = params.problem_size[1] * params.problem_size[2];

View File

@@ -35,7 +35,7 @@ struct DefaultBlockSwizzle {
   CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {}
   /// Swizzle the block index.
-  CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
+  CUTLASS_DEVICE dim3 swizzle() { return {blockIdx.x, blockIdx.y, blockIdx.z}; }
   ///
   CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size,

View File

@@ -30,20 +30,20 @@
 #if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
 #define CUTLASS_USE_WMMA_API
-#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 720)
+#if defined(__CUDACC__) && (CUDA_VERSION >= 10000) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 720)
 #define CUTLASS_USE_INT_WMMA
 #endif
-#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
+#if defined(__CUDACC__) && (CUDA_VERSION >= 10000) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
 #define CUTLASS_USE_SUBBYTE_WMMA
 #endif
 #include "stdio.h"
-#if __CUDACC_VER_MAJOR__ >= 10
+#if CUDA_VERSION >= 10000
 #include <mma.h>
 #else
-#include <crt/mma.h>
+#include <mma.h>
 #endif
 #include "cutlass/fragment.h"
 #include "cutlass/matrix_traits.h"

View File

@@ -142,12 +142,17 @@ int Layout::operator()(Layout::Coordinate const &_coord) const {
 }
+// test::Layout::Coordinate is actually a std::vector<>, so for ADL lookup to
+// work, the operator<< must be in std::. GCC does look it up in global
+// namespace, but that's a bug.
+namespace std {
 std::ostream & operator<<(std::ostream &out, test::Layout::Coordinate const &coord) {
   for (int i = 0; i < coord.size(); ++i) {
     out << (i ? ", " : "") << coord.at(i);
   }
   return out;
 }
+} // namespace std
 ///////////////////////////////////////////////////////////////////////////////////////////////////
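For context on the comment above: an operator<< declared after a template that uses it (here, the test framework's value printers) can only be found through argument-dependent lookup at instantiation time, and for a plain std::vector<int> the associated namespace is std, not the global namespace. A sketch of the lookup situation (names are illustrative):

#include <ostream>
#include <vector>

using Coordinate = std::vector<int>;  // stand-in for test::Layout::Coordinate

namespace printer {
template <typename T>
void print(std::ostream& out, T const& value) {
  out << value;  // ADL searches std (associated with std::vector), not ::
}
}  // namespace printer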

View File

@@ -94,7 +94,9 @@ class Layout {
 }
 /// Implemented in layout_verification.cu
+namespace std {
 std::ostream& operator<<(std::ostream& out, test::Layout::Coordinate const& coord);
+}
 namespace test {

View File

@@ -42,9 +42,7 @@ __global__ void load_store_global(
     typename cutlass::TileStoreIterator<Traits, Scalar, cutlass::IteratorAdvance::kH,
                                         cutlass::MemorySpace::kGlobal>::Scalar *output,
     int kW,
-    int kH,
-    typename cutlass::TileStoreIterator<Traits, Scalar, cutlass::IteratorAdvance::kH,
-                                        cutlass::MemorySpace::kGlobal>::Scalar identity = 0
+    int kH
 ) {
   /// Load iterator

View File

@@ -86,7 +86,7 @@ class half_t {
   half_t operator+(half_t const&) const;
   half_t operator-() const;
   half_t operator-(half_t const&) const;
-  half_t operator*(half_t const&)const;
+  half_t operator*(half_t const&) const;
   half_t operator/(half_t const&) const;
   half_t& operator+=(half_t const&);
@@ -107,6 +107,12 @@ class half_t {
   uint16_t& raw() { return x; }
   uint16_t raw() const { return x; }
+#if defined(__clang__)
+  __device__ half_t operator+(half_t const&) const;
+  __device__ half_t operator*(half_t const&) const;
+  __device__ operator float() const; /// conversion to fp32
+#endif
   //
   // Stream interactions
   //
@@ -209,7 +215,7 @@ std::string lexical_cast<std::string>(cutlass::half_t const& arg);
 #define HLF_MANT_DIG 10
-namespace std {
+namespace cutlass {
 cutlass::half_t abs(cutlass::half_t const&); /// absolute value
@@ -229,7 +235,10 @@ int fpclassify(cutlass::half_t const&); /// returns a flag classifying floating
 bool signbit(cutlass::half_t const&); /// returns true if negative, false if positive
 cutlass::half_t sqrt(cutlass::half_t const&); /// square root of half_t
+cutlass::half_t copysign(cutlass::half_t const&, cutlass::half_t const&);
+}
+namespace std {
 /// Numeric limits
 template <>
 struct numeric_limits<cutlass::half_t> {
@@ -696,8 +705,7 @@ std::string lexical_cast<std::string>(cutlass::half_t const& arg) {
 // Standard Library Operations
 //
-// std
-namespace std {
+namespace cutlass {
 inline cutlass::half_t abs(cutlass::half_t const& h) {
   return cutlass::half_t::bitcast(h.x & 0x7fff);
@@ -737,4 +745,8 @@ inline bool signbit(cutlass::half_t const& h) { return h.signbit(); }
 inline cutlass::half_t sqrt(cutlass::half_t const& h) {
   return cutlass::half_t(std::sqrt(float(h)));
 }
+inline cutlass::half_t copysign(cutlass::half_t const& a,
+                                cutlass::half_t const& b) {
+  return cutlass::half_t(std::copysign(float(a), float(b)));
+}
 } // namespace std
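With these functions now declared in namespace cutlass rather than namespace std (the standard does not permit adding new declarations to std, which is presumably what Clang tripped over here), callers either qualify them explicitly or rely on ADL on the cutlass::half_t argument. A short usage sketch, assuming the float constructor and the signbit() member declared elsewhere in this header:

#include "cutlass/half.h"

int main() {
  cutlass::half_t x(-1.5f);
  cutlass::half_t a = cutlass::abs(x);          // +1.5; also reachable as abs(x) via ADL
  cutlass::half_t c = cutlass::copysign(a, x);  // -1.5: magnitude of a, sign of x
  return c.signbit() ? 0 : 1;
}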

View File

@@ -45,6 +45,12 @@ Ctype inner_product(Atype a, Btype b, Ctype c) {
   return Ctype(a) * Ctype(b) + c;
 }
+#if defined(__clang__) && defined(__CUDA__)
+__device__ __forceinline__ __half inner_product(__half a, __half b, __half c) {
+  return a * b + c;
+}
+#endif
 /// Specialization for matrix multiplication with binary operands
 template <>
 CUTLASS_HOST_DEVICE
@@ -124,4 +130,3 @@ struct Cast<float, uint8_t> {
 } // namespace detail
 } // namespace reference
 } // namespace cutlass
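The new overload relies on Clang supporting __half arithmetic operators in device code. For reference, an equivalent form written against the fp16 intrinsics would look roughly like the sketch below (an illustration, not part of the commit; requires sm_53 or newer):

#include <cuda_fp16.h>

// Returns a * b + c in half precision, using the fused multiply-add intrinsic.
__device__ __forceinline__ __half inner_product_hfma(__half a, __half b, __half c) {
  return __hfma(a, b, c);
}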

View File

@@ -142,7 +142,7 @@ struct Gemm {
   }
   /// Performs linear scaling of matrix product and updates output tensor
-  CUTLASS_HOST_DEVICE
+  __device__
   Gemm & epilogue(
     gemm::GemmCoord problem_size,
     ScalarType alpha,

View File

@@ -140,7 +140,7 @@ struct TypeTraits<half> {
   typedef int16_t integer_type;
   typedef uint16_t unsigned_type;
   static inline half remove_negative_zero(half x) {
-    integer_type h_int = reinterpret_cast<integer_type const&>(x);
+    unsigned_type h_int = reinterpret_cast<unsigned_type const&>(x);
     if (h_int == 0x8000) {
       h_int = 0;
     }
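The switch from integer_type (int16_t) to unsigned_type (uint16_t) fixes a comparison Clang flags as tautological: 0x8000 is the int value 32768, which never fits in int16_t, so with a signed h_int the negative-zero test could never succeed. A stand-alone illustration (not CUTLASS code):

#include <cstdint>
#include <cstdio>

int main() {
  int16_t  s = (int16_t)0x8000;  // fp16 negative-zero bit pattern; value -32768 in practice
  uint16_t u = 0x8000;           // same bits, value 32768

  std::printf("signed   == 0x8000: %d\n", s == 0x8000);  // 0: -32768 != 32768 after promotion
  std::printf("unsigned == 0x8000: %d\n", u == 0x8000);  // 1: the check works as intended
  return 0;
}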