Fix host compilation of cute::cast_smem_ptr_to_uint. (#940)

* Remove references to device-only intrinsics when compiling for host.

Currently, we attempt to use the `__device__`-only functions
`__cvta_generic_to_shared` and `__nvvm_get_smem_pointer` when compiling
`cute::cast_smem_ptr_to_uint` for the host on Clang. This results in a
compilation error, as expected. This commit changes the definition of
the `*_ACTIVATED` macros so that they are only true when `__CUDA_ARCH__`
is defined; that is, when compiling for the device.

Additionally, the declaration of `__nvvm_get_smem_pointer`
is currently only visible during the device compilation pass when
compiling with NVCC; this commit makes the declaration visible during
host compilation with the `__device__` annotation.

* Annotate cute::cast_smem_ptr_to_uint as device-only.

The implementation of `cute::cast_smem_ptr_to_uint` is currently an
unchecked failure on host code, and the only host implementation I can
think of -- casting a probably-64-bit pointer to 32 bits somehow --
doesn't make sense to implement. This commit marks this function as
device-only so that it can't be accidentally used on host code.

* small change

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Gregory Meyer (gregjm) 2023-05-09 21:06:54 -07:00 committed by GitHub
parent b250faccd3
commit fcfbd23e26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,39 +36,43 @@
#if defined(__clang__) && defined(__CUDA__)
// __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665
#define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED (__clang_major__ >= 14)
#if __clang_major__ >= 14
#define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED 1
#endif
#ifndef _WIN32
// __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 14)
#else
// ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 15)
// ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897
#if (!defined(_WIN32) && __clang_major__ >= 14) || __clang_major__ >= 15
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER 1
#endif
#endif
#if defined(__NVCC__) || defined(__CUDACC_RTC__)
// __cvta_generic_to_shared added in CUDA 11+
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11)
#if __CUDACC_VER_MAJOR__ >= 11
#define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1
#endif
// __nvvm_get_smem_pointer added in CUDA 10.2
#if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2
#if __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2
#define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1
#endif
#endif
#define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED (CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED)
#ifndef CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED
#define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED
#if CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED
#define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1
#endif
#define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED (CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER)
#if !defined(CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED) && CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED && defined(__CUDA_ARCH__)
#define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED 1
#endif
#ifndef CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED
#define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED
#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER
#define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1
#endif
#if !defined(CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED) && CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED && defined(__CUDA_ARCH__)
#define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED 1
#endif
// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need
@ -85,7 +89,7 @@ namespace cute
{
/// CUTE helper to cast SMEM pointer to unsigned
CUTE_HOST_DEVICE
CUTE_DEVICE
uint32_t
cast_smem_ptr_to_uint(void const* const ptr)
{