From 4dbf5dbed2331b948b75a3dbeaf760d76b3b5964 Mon Sep 17 00:00:00 2001
From: shunfan-shao <79347016+shunfan-shao@users.noreply.github.com>
Date: Mon, 19 Aug 2024 10:26:09 -0700
Subject: [PATCH] Use CUDA runtime API to retrieve function pointer to driver API (#1700)

* Query pfn to driver api

* use default for older toolkits

---------

Co-authored-by: shunfans
---
Review notes (ignored by `git am`; not part of the commit message):

 - CMakeLists.txt: closed the unterminated help string of
   CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL.
 - cuda_host_adapter.hpp: the toolkit check for
   cudaGetDriverEntryPointByVersion now reads "newer than 12, or 12 with
   minor >= 5", so a later major release with a small minor number no longer
   falls back to the unversioned query.
 - cuda_host_adapter.hpp: cuda_err is tested before cuda_status, so the
   uninitialised status value is never read when the runtime call itself
   fails.
 - NOTE(review): this copy of the patch had lost text enclosed in angle
   brackets. The `template <typename... Args>` lists, the
   `PFN_##func##_v##ver` cast target and the <cudaTypedefs.h> /
   <driver_types.h> include names were reinstated from how the macros use
   them -- confirm against upstream. The two context `#include` lines in the
   first copy_traits_sm90_tma.hpp hunk, the template arguments in that file's
   second hunk header and the co-author address could not be recovered;
   blank-line placement and context indentation are best-effort. Verify the
   context against the target tree before applying.

 CMakeLists.txt                                |  1 +
 include/cute/atom/copy_traits_sm90_im2col.hpp |  4 +-
 include/cute/atom/copy_traits_sm90_tma.hpp    |  3 +-
 include/cutlass/cuda_host_adapter.hpp         | 72 +++++++++++++++++++
 4 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ac67eb86..4e1ffd75 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -234,6 +234,7 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries
 set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
 set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
 set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
+set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API")

 ################################################################################
 #
diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp
index f6c9e258..ad4f8675 100644
--- a/include/cute/atom/copy_traits_sm90_im2col.hpp
+++ b/include/cute/atom/copy_traits_sm90_im2col.hpp
@@ -40,6 +40,8 @@

 #include "cute/algorithm/prefetch.hpp"
 #include "cutlass/fast_math.h"
+#include "cutlass/cuda_host_adapter.hpp"
+
 namespace cute
 {

@@ -450,7 +452,7 @@ make_im2col_tma_copy_desc(
   CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_);
   CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle));

-  CUresult encode_result = cuTensorMapEncodeIm2col(
+  CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
       &tma_desc,
       tma_format,
       num_total_modes,
diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp
index 950855a1..2238c418 100644
--- a/include/cute/atom/copy_traits_sm90_tma.hpp
+++ b/include/cute/atom/copy_traits_sm90_tma.hpp
@@ -41,6 +41,7 @@
 #include

 #include
+#include <cutlass/cuda_host_adapter.hpp>

 namespace cute
 {
@@ -983,7 +984,7 @@ make_tma_copy_desc(Tensor const& gtensor, // The origin

     // TMA smem swizzle type
     CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle));
-    CUresult result = cuTensorMapEncodeTiled(
+    CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
         &tma_desc,
         tma_format,
         tma_dim,
diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp
index 28f5ae0e..f9ff723c 100644
--- a/include/cutlass/cuda_host_adapter.hpp
+++ b/include/cutlass/cuda_host_adapter.hpp
@@ -82,6 +82,78 @@ namespace cutlass {

 /////////////////////////////////////////////////////////////////////////////////////////////////

+#if !defined(__CUDACC_RTC__)
+
+#include <cudaTypedefs.h>
+#include <driver_types.h>
+
+#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok
+
+#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
+
+#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
+  template <typename... Args>                       \
+  CUresult call_##func(Args... args) {              \
+    return func(args...);                           \
+  }
+
+#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
+
+#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
+
+#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver)             \
+  template <typename... Args>                                   \
+  CUresult call_##func(Args... args) {                          \
+    cudaDriverEntryPointQueryResult cuda_status;                \
+    void* pfn = nullptr;                                        \
+    cudaError_t cuda_err = cudaGetDriverEntryPointByVersion(    \
+        CUTLASS_CUDA_DRIVER_STRINGIFY(func),                    \
+        &pfn, ver,                                              \
+        cudaEnableDefault,                                      \
+        &cuda_status);                                          \
+    if (cuda_err != cudaSuccess ||                              \
+        cuda_status != cudaDriverEntryPointSuccess) {           \
+      return CUDA_ERROR_UNKNOWN;                                \
+    }                                                           \
+    return reinterpret_cast<PFN_##func##_v##ver>(pfn)(args...); \
+  }
+
+#else
+
+#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver)             \
+  template <typename... Args>                                   \
+  CUresult call_##func(Args... args) {                          \
+    cudaDriverEntryPointQueryResult cuda_status;                \
+    void* pfn = nullptr;                                        \
+    cudaError_t cuda_err = cudaGetDriverEntryPoint(             \
+        CUTLASS_CUDA_DRIVER_STRINGIFY(func),                    \
+        &pfn,                                                   \
+        cudaEnableDefault,                                      \
+        &cuda_status);                                          \
+    if (cuda_err != cudaSuccess ||                              \
+        cuda_status != cudaDriverEntryPointSuccess) {           \
+      return CUDA_ERROR_UNKNOWN;                                \
+    }                                                           \
+    return reinterpret_cast<PFN_##func##_v##ver>(pfn)(args...); \
+  }
+
+#endif // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
+
+#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
+
+#if (__CUDACC_VER_MAJOR__ >= 12)
+CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000);
+CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000);
+#endif
+
+#undef CUTLASS_CUDA_DRIVER_STRINGIFY
+
+#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func
+
+#endif // !defined(__CUDACC_RTC__)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 /// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter
 /// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute
 /// is not introduced.