diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index 205951fe..69f0a855 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -36,39 +36,43 @@ #if defined(__clang__) && defined(__CUDA__) // __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665 - #define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED (__clang_major__ >= 14) + #if __clang_major__ >= 14 + #define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif - #ifndef _WIN32 // __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665 - #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 14) - #else - // ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897 - #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 15) + // ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897 + #if (!defined(_WIN32) && __clang_major__ >= 14) || __clang_major__ >= 15 + #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER 1 #endif #endif #if defined(__NVCC__) || defined(__CUDACC_RTC__) // __cvta_generic_to_shared added in CUDA 11+ - #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) + #if __CUDACC_VER_MAJOR__ >= 11 #define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 #endif // __nvvm_get_smem_pointer added in CUDA 10.2 - #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2 + #if __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2 #define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1 #endif #endif -#define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED (CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED) - -#ifndef CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED - #define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED +#if CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED + #define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1 #endif -#define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED (CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER) +#if !defined(CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED) && CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED 1 +#endif -#ifndef CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED - #define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER + #define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1 +#endif + +#if !defined(CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED) && CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED 1 #endif // Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need @@ -85,7 +89,7 @@ namespace cute { /// CUTE helper to cast SMEM pointer to unsigned -CUTE_HOST_DEVICE +CUTE_DEVICE uint32_t cast_smem_ptr_to_uint(void const* const ptr) {