Enable shared memory intrinsics and ldmatrix PTX on Clang. (#754)

* Enable shared memory intrinsics and ldmatrix PTX on Clang.

This commit adds preprocessor checks to enable the shared memory
intrinsics `__cvta_generic_to_shared` and `__nvvm_get_smem_pointer`, as
well as the `ldmatrix` PTX instructions, on Clang. Preventing these
intrinsics from being used is a significant latency regression on Clang.

* refine the macro

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Gregory Meyer (gregjm) 2023-03-31 18:42:24 -07:00 committed by GitHub
parent 660a05f581
commit ecbd24566c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 34 deletions

View File

@ -35,8 +35,29 @@
#include <cute/arch/copy.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
# define CUTE_ARCH_LDSM_SM75_ENABLED
#if defined(__clang__) && defined(__CUDA__)
// ldmatrix PTX instructions added in Clang 14: https://reviews.llvm.org/D107046
// ... but broken until Clang 15:
// * https://reviews.llvm.org/D121666
// * https://reviews.llvm.org/D126846
#define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15)
#endif
#if defined(__NVCC__) || defined(__CUDACC_RTC__)
// ldmatrix PTX instruction added in CUDA 10.2+
#define CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11)
#endif
#if ! defined(CUTE_ARCH_LDSM_SM75_SUPPORTED)
#define CUTE_ARCH_LDSM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75)
#endif
#if ! defined(CUTE_ARCH_LDSM_SM75_ENABLED)
#define CUTE_ARCH_LDSM_SM75_ENABLED (CUTE_ARCH_LDSM_SM75_SUPPORTED)
#endif
#if (CUTE_ARCH_LDSM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#define CUTE_ARCH_LDSM_SM75_ACTIVATED 1
#endif
namespace cute
@ -51,13 +72,13 @@ struct SM75_U32x1_LDSM_N
copy(uint128_t const& smem_src,
uint32_t& dst)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(dst)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -71,13 +92,13 @@ struct SM75_U32x2_LDSM_N
copy(uint128_t const& smem_src,
uint32_t& dst0, uint32_t& dst1)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst0), "=r"(dst1)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -91,13 +112,13 @@ struct SM75_U32x4_LDSM_N
copy(uint128_t const& smem_src,
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -111,13 +132,13 @@ struct SM75_U16x2_LDSM_T
copy(uint128_t const& smem_src,
uint32_t& dst)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(dst)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -131,13 +152,13 @@ struct SM75_U16x4_LDSM_T
copy(uint128_t const& smem_src,
uint32_t& dst0, uint32_t& dst1)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst0), "=r"(dst1)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};
@ -151,13 +172,13 @@ struct SM75_U16x8_LDSM_T
copy(uint128_t const& smem_src,
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
{
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
: "r"(smem_int_ptr));
#else
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
#endif
}
};

View File

@ -34,7 +34,42 @@
#include <cute/numeric/integer_sequence.hpp>
#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
#if defined(__clang__) && defined(__CUDA__)
// __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665
#define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED (__clang_major__ >= 14)
#ifndef _WIN32
// __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 14)
#else
// ... but broken on Windows until Clang 15: https://reviews.llvm.org/D122897
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 15)
#endif
#endif
#if defined(__NVCC__) || defined(__CUDACC_RTC__)
// __cvta_generic_to_shared added in CUDA 11+
#define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11))
// __nvvm_get_smem_pointer added in CUDA 10.2
#define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
#endif
#define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED (CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED)
#ifndef CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED
#define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED
#endif
#define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED (CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER)
#ifndef CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED
#define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED
#endif
// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need
// to provide one for NVCC
#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER
extern "C" {
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly.
@ -52,7 +87,7 @@ cast_smem_ptr_to_uint(void const* const ptr)
{
// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
// the previous internal intrinsics if they are available.
#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
#if CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED
//
// This NVVM intrinsic converts an address in shared memory to a plain
// unsigned integer. This is necessary to pass to shared memory instructions
@ -65,7 +100,7 @@ cast_smem_ptr_to_uint(void const* const ptr)
/// CUTE helper to get SMEM pointer
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
#elif CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED
return __nvvm_get_smem_pointer(ptr);

View File

@ -36,6 +36,7 @@
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cute/arch/copy_sm75.hpp"
#include "cute/arch/util.hpp"
namespace cutlass {
@ -57,17 +58,6 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
//
/////////////////////////////////////////////////////////////////////////////////////////////////
#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
#define CUDA_LDMATRIX_ACTIVATED 1
#endif
#define CUDA_LDMATRIX_SUPPORTED 1
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
return cute::cast_smem_ptr_to_uint(ptr);
@ -85,7 +75,7 @@ inline __device__ void ldsm<layout::RowMajor, 1>(
Array<unsigned, 1> & D,
void const* ptr) {
#if defined(CUDA_LDMATRIX_ACTIVATED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
unsigned addr = cutlass_get_smem_pointer(ptr);
@ -109,7 +99,7 @@ inline __device__ void ldsm<layout::RowMajor, 2>(
Array<unsigned, 2> & D,
void const* ptr) {
#if defined(CUDA_LDMATRIX_ACTIVATED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
unsigned addr = cutlass_get_smem_pointer(ptr);
@ -133,7 +123,7 @@ inline __device__ void ldsm<layout::RowMajor, 4>(
Array<unsigned, 4> & D,
void const* ptr) {
#if defined(CUDA_LDMATRIX_ACTIVATED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
unsigned addr = cutlass_get_smem_pointer(ptr);
@ -161,7 +151,7 @@ inline __device__ void ldsm<layout::ColumnMajor, 1>(
Array<unsigned, 1> & D,
void const* ptr) {
#if CUDA_LDMATRIX_ACTIVATED
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
unsigned addr = cutlass_get_smem_pointer(ptr);
@ -185,7 +175,7 @@ inline __device__ void ldsm<layout::ColumnMajor, 2>(
Array<unsigned, 2> & D,
void const* ptr) {
#if defined(CUDA_LDMATRIX_ACTIVATED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
unsigned addr = cutlass_get_smem_pointer(ptr);
@ -209,7 +199,7 @@ inline __device__ void ldsm<layout::ColumnMajor, 4>(
Array<unsigned, 4> & D,
void const* ptr) {
#if defined(CUDA_LDMATRIX_ACTIVATED)
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
unsigned addr = cutlass_get_smem_pointer(ptr);