Enable shared memory intrinsics and ldmatrix PTX on Clang. (#754)
* Enable shared memory intrinsics and ldmatrix PTX on Clang.

  This commit adds preprocessor checks to enable the shared memory intrinsics `__cvta_generic_to_shared` and `__nvvm_get_smem_pointer`, as well as the `ldmatrix` PTX instructions, on Clang. Leaving these intrinsics disabled causes a significant latency regression on Clang.

* refine the macro

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
660a05f581
commit
ecbd24566c
@ -35,8 +35,29 @@
|
|||||||
#include <cute/arch/copy.hpp>
|
#include <cute/arch/copy.hpp>
|
||||||
|
|
||||||
// Config
|
// Config
|
||||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
|
#if defined(__clang__) && defined(__CUDA__)
|
||||||
# define CUTE_ARCH_LDSM_SM75_ENABLED
|
// ldmatrix PTX instructions added in Clang 14: https://reviews.llvm.org/D107046
|
||||||
|
// ... but broken until Clang 15:
|
||||||
|
// * https://reviews.llvm.org/D121666
|
||||||
|
// * https://reviews.llvm.org/D126846
|
||||||
|
#define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__NVCC__) || defined(__CUDACC_RTC__)
|
||||||
|
// ldmatrix PTX instruction added in CUDA 10.2+
|
||||||
|
#define CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if ! defined(CUTE_ARCH_LDSM_SM75_SUPPORTED)
|
||||||
|
#define CUTE_ARCH_LDSM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if ! defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
||||||
|
#define CUTE_ARCH_LDSM_SM75_ENABLED (CUTE_ARCH_LDSM_SM75_SUPPORTED)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if (CUTE_ARCH_LDSM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||||
|
#define CUTE_ARCH_LDSM_SM75_ACTIVATED 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace cute
|
namespace cute
|
||||||
@ -51,13 +72,13 @@ struct SM75_U32x1_LDSM_N
|
|||||||
copy(uint128_t const& smem_src,
|
copy(uint128_t const& smem_src,
|
||||||
uint32_t& dst)
|
uint32_t& dst)
|
||||||
{
|
{
|
||||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||||
asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
|
asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
|
||||||
: "=r"(dst)
|
: "=r"(dst)
|
||||||
: "r"(smem_int_ptr));
|
: "r"(smem_int_ptr));
|
||||||
#else
|
#else
|
||||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -71,13 +92,13 @@ struct SM75_U32x2_LDSM_N
|
|||||||
copy(uint128_t const& smem_src,
|
copy(uint128_t const& smem_src,
|
||||||
uint32_t& dst0, uint32_t& dst1)
|
uint32_t& dst0, uint32_t& dst1)
|
||||||
{
|
{
|
||||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||||
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
|
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
|
||||||
: "=r"(dst0), "=r"(dst1)
|
: "=r"(dst0), "=r"(dst1)
|
||||||
: "r"(smem_int_ptr));
|
: "r"(smem_int_ptr));
|
||||||
#else
|
#else
|
||||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -91,13 +112,13 @@ struct SM75_U32x4_LDSM_N
|
|||||||
copy(uint128_t const& smem_src,
|
copy(uint128_t const& smem_src,
|
||||||
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
|
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
|
||||||
{
|
{
|
||||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||||
asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||||
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
|
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
|
||||||
: "r"(smem_int_ptr));
|
: "r"(smem_int_ptr));
|
||||||
#else
|
#else
|
||||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -111,13 +132,13 @@ struct SM75_U16x2_LDSM_T
|
|||||||
copy(uint128_t const& smem_src,
|
copy(uint128_t const& smem_src,
|
||||||
uint32_t& dst)
|
uint32_t& dst)
|
||||||
{
|
{
|
||||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||||
asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
|
asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
|
||||||
: "=r"(dst)
|
: "=r"(dst)
|
||||||
: "r"(smem_int_ptr));
|
: "r"(smem_int_ptr));
|
||||||
#else
|
#else
|
||||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -131,13 +152,13 @@ struct SM75_U16x4_LDSM_T
|
|||||||
copy(uint128_t const& smem_src,
|
copy(uint128_t const& smem_src,
|
||||||
uint32_t& dst0, uint32_t& dst1)
|
uint32_t& dst0, uint32_t& dst1)
|
||||||
{
|
{
|
||||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||||
asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
|
asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
|
||||||
: "=r"(dst0), "=r"(dst1)
|
: "=r"(dst0), "=r"(dst1)
|
||||||
: "r"(smem_int_ptr));
|
: "r"(smem_int_ptr));
|
||||||
#else
|
#else
|
||||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -151,13 +172,13 @@ struct SM75_U16x8_LDSM_T
|
|||||||
copy(uint128_t const& smem_src,
|
copy(uint128_t const& smem_src,
|
||||||
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
|
uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
|
||||||
{
|
{
|
||||||
#if defined(CUTE_ARCH_LDSM_SM75_ENABLED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src);
|
||||||
asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||||
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
|
: "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
|
||||||
: "r"(smem_int_ptr));
|
: "r"(smem_int_ptr));
|
||||||
#else
|
#else
|
||||||
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ENABLED.");
|
CUTE_RUNTIME_ASSERT("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -34,7 +34,42 @@
|
|||||||
|
|
||||||
#include <cute/numeric/integer_sequence.hpp>
|
#include <cute/numeric/integer_sequence.hpp>
|
||||||
|
|
||||||
#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
|
#if defined(__clang__) && defined(__CUDA__)
|
||||||
|
// __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665
|
||||||
|
#define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED (__clang_major__ >= 14)
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
// __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665
|
||||||
|
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 14)
|
||||||
|
#else
|
||||||
|
// ... but broken on Windows until Clang 15: https://reviews.llvm.org/D122897
|
||||||
|
#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 15)
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__NVCC__) || defined(__CUDACC_RTC__)
|
||||||
|
// __cvta_generic_to_shared added in CUDA 11+
|
||||||
|
#define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11))
|
||||||
|
|
||||||
|
// __nvvm_get_smem_pointer added in CUDA 10.2
|
||||||
|
#define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED (CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED)
|
||||||
|
|
||||||
|
#ifndef CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED
|
||||||
|
#define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED (CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER)
|
||||||
|
|
||||||
|
#ifndef CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED
|
||||||
|
#define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need
|
||||||
|
// to provide one for NVCC
|
||||||
|
#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER
|
||||||
extern "C" {
|
extern "C" {
|
||||||
// This NVVM intrinsic is subject to change in future versions of CUDA.
|
// This NVVM intrinsic is subject to change in future versions of CUDA.
|
||||||
// Clients should not call it directly.
|
// Clients should not call it directly.
|
||||||
@ -52,7 +87,7 @@ cast_smem_ptr_to_uint(void const* const ptr)
|
|||||||
{
|
{
|
||||||
// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
|
// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
|
||||||
// the previous internal intrinsics if they are available.
|
// the previous internal intrinsics if they are available.
|
||||||
#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
|
#if CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED
|
||||||
//
|
//
|
||||||
// This NVVM intrinsic converts an address in shared memory to a plain
|
// This NVVM intrinsic converts an address in shared memory to a plain
|
||||||
// unsigned integer. This is necessary to pass to shared memory instructions
|
// unsigned integer. This is necessary to pass to shared memory instructions
|
||||||
@ -65,7 +100,7 @@ cast_smem_ptr_to_uint(void const* const ptr)
|
|||||||
/// CUTE helper to get SMEM pointer
|
/// CUTE helper to get SMEM pointer
|
||||||
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
|
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
|
||||||
|
|
||||||
#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
|
#elif CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED
|
||||||
|
|
||||||
return __nvvm_get_smem_pointer(ptr);
|
return __nvvm_get_smem_pointer(ptr);
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@
|
|||||||
|
|
||||||
#include "cutlass/array.h"
|
#include "cutlass/array.h"
|
||||||
#include "cutlass/layout/matrix.h"
|
#include "cutlass/layout/matrix.h"
|
||||||
|
#include "cute/arch/copy_sm75.hpp"
|
||||||
#include "cute/arch/util.hpp"
|
#include "cute/arch/util.hpp"
|
||||||
|
|
||||||
namespace cutlass {
|
namespace cutlass {
|
||||||
@ -57,17 +58,6 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
|
|||||||
//
|
//
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11)
|
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
|
|
||||||
#define CUDA_LDMATRIX_ACTIVATED 1
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define CUDA_LDMATRIX_SUPPORTED 1
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/// CUTLASS helper to get SMEM pointer
|
/// CUTLASS helper to get SMEM pointer
|
||||||
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
|
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
|
||||||
return cute::cast_smem_ptr_to_uint(ptr);
|
return cute::cast_smem_ptr_to_uint(ptr);
|
||||||
@ -85,7 +75,7 @@ inline __device__ void ldsm<layout::RowMajor, 1>(
|
|||||||
Array<unsigned, 1> & D,
|
Array<unsigned, 1> & D,
|
||||||
void const* ptr) {
|
void const* ptr) {
|
||||||
|
|
||||||
#if defined(CUDA_LDMATRIX_ACTIVATED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
|
|
||||||
unsigned addr = cutlass_get_smem_pointer(ptr);
|
unsigned addr = cutlass_get_smem_pointer(ptr);
|
||||||
|
|
||||||
@ -109,7 +99,7 @@ inline __device__ void ldsm<layout::RowMajor, 2>(
|
|||||||
Array<unsigned, 2> & D,
|
Array<unsigned, 2> & D,
|
||||||
void const* ptr) {
|
void const* ptr) {
|
||||||
|
|
||||||
#if defined(CUDA_LDMATRIX_ACTIVATED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
|
|
||||||
unsigned addr = cutlass_get_smem_pointer(ptr);
|
unsigned addr = cutlass_get_smem_pointer(ptr);
|
||||||
|
|
||||||
@ -133,7 +123,7 @@ inline __device__ void ldsm<layout::RowMajor, 4>(
|
|||||||
Array<unsigned, 4> & D,
|
Array<unsigned, 4> & D,
|
||||||
void const* ptr) {
|
void const* ptr) {
|
||||||
|
|
||||||
#if defined(CUDA_LDMATRIX_ACTIVATED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
|
|
||||||
unsigned addr = cutlass_get_smem_pointer(ptr);
|
unsigned addr = cutlass_get_smem_pointer(ptr);
|
||||||
|
|
||||||
@ -161,7 +151,7 @@ inline __device__ void ldsm<layout::ColumnMajor, 1>(
|
|||||||
Array<unsigned, 1> & D,
|
Array<unsigned, 1> & D,
|
||||||
void const* ptr) {
|
void const* ptr) {
|
||||||
|
|
||||||
#if CUDA_LDMATRIX_ACTIVATED
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
|
|
||||||
unsigned addr = cutlass_get_smem_pointer(ptr);
|
unsigned addr = cutlass_get_smem_pointer(ptr);
|
||||||
|
|
||||||
@ -185,7 +175,7 @@ inline __device__ void ldsm<layout::ColumnMajor, 2>(
|
|||||||
Array<unsigned, 2> & D,
|
Array<unsigned, 2> & D,
|
||||||
void const* ptr) {
|
void const* ptr) {
|
||||||
|
|
||||||
#if defined(CUDA_LDMATRIX_ACTIVATED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
|
|
||||||
unsigned addr = cutlass_get_smem_pointer(ptr);
|
unsigned addr = cutlass_get_smem_pointer(ptr);
|
||||||
|
|
||||||
@ -209,7 +199,7 @@ inline __device__ void ldsm<layout::ColumnMajor, 4>(
|
|||||||
Array<unsigned, 4> & D,
|
Array<unsigned, 4> & D,
|
||||||
void const* ptr) {
|
void const* ptr) {
|
||||||
|
|
||||||
#if defined(CUDA_LDMATRIX_ACTIVATED)
|
#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED)
|
||||||
|
|
||||||
unsigned addr = cutlass_get_smem_pointer(ptr);
|
unsigned addr = cutlass_get_smem_pointer(ptr);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user