Use CUDA runtime API to retrieve function pointer to driver API (#1700)

* Query pfn to driver api

* use default for older toolkits

---------

Co-authored-by: shunfans <shunfans@nvidia.com>
This commit is contained in:
shunfan-shao 2024-08-19 10:26:09 -07:00 committed by GitHub
parent f93a69134e
commit 4dbf5dbed2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 78 additions and 2 deletions

View File

@ -234,6 +234,7 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API)
################################################################################
#

View File

@ -40,6 +40,8 @@
#include "cute/algorithm/prefetch.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/cuda_host_adapter.hpp"
namespace cute
{
@ -450,7 +452,7 @@ make_im2col_tma_copy_desc(
CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_);
CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle));
CUresult encode_result = cuTensorMapEncodeIm2col(
CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
&tma_desc,
tma_format,
num_total_modes,

View File

@ -41,6 +41,7 @@
#include <cute/algorithm/prefetch.hpp>
#include <cute/numeric/integral_ratio.hpp>
#include <cutlass/cuda_host_adapter.hpp>
namespace cute
{
@ -983,7 +984,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The origin
// TMA smem swizzle type
CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle));
CUresult result = cuTensorMapEncodeTiled(
CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tma_desc,
tma_format,
tma_dim,

View File

@ -82,6 +82,78 @@ namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(__CUDACC_RTC__)
#include <cudaTypedefs.h>
#include <driver_types.h>
#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok
#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
template <typename... Args> \
CUresult call_##func(Args... args) { \
return func(args...); \
}
#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5)
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
template <typename... Args> \
CUresult call_##func(Args... args) { \
cudaDriverEntryPointQueryResult cuda_status; \
void* pfn = nullptr; \
cudaError_t cuda_err = cudaGetDriverEntryPointByVersion( \
CUTLASS_CUDA_DRIVER_STRINGIFY(func), \
&pfn, ver, \
cudaEnableDefault, \
&cuda_status); \
if (cuda_status != cudaDriverEntryPointSuccess || \
cuda_err != cudaSuccess) { \
return CUDA_ERROR_UNKNOWN; \
} \
return reinterpret_cast<PFN_##func##_v##ver>(pfn)(args...); \
}
#else
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
template <typename... Args> \
CUresult call_##func(Args... args) { \
cudaDriverEntryPointQueryResult cuda_status; \
void* pfn = nullptr; \
cudaError_t cuda_err = cudaGetDriverEntryPoint( \
CUTLASS_CUDA_DRIVER_STRINGIFY(func), \
&pfn, \
cudaEnableDefault, \
&cuda_status); \
if (cuda_status != cudaDriverEntryPointSuccess || \
cuda_err != cudaSuccess) { \
return CUDA_ERROR_UNKNOWN; \
} \
return reinterpret_cast<PFN_##func>(pfn)(args...); \
}
#endif // (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5)
#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
#if (__CUDACC_VER_MAJOR__ >= 12)
CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000);
CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000);
#endif
#undef CUTLASS_CUDA_DRIVER_STRINGIFY
#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func
#endif // !defined(__CUDACC_RTC__)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter
/// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute
/// is not introduced.