Use CUDA runtime API to retrieve function pointer to driver API (#1700)
* Query pfn to driver api * use default for older toolkits --------- Co-authored-by: shunfans <shunfans@nvidia.com>
This commit is contained in:
parent
f93a69134e
commit
4dbf5dbed2
@ -234,6 +234,7 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries
|
||||
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
|
||||
set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API)
|
||||
|
||||
################################################################################
|
||||
#
|
||||
|
||||
@ -40,6 +40,8 @@
|
||||
|
||||
#include "cute/algorithm/prefetch.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/cuda_host_adapter.hpp"
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
@ -450,7 +452,7 @@ make_im2col_tma_copy_desc(
|
||||
CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_);
|
||||
CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle));
|
||||
|
||||
CUresult encode_result = cuTensorMapEncodeIm2col(
|
||||
CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
|
||||
&tma_desc,
|
||||
tma_format,
|
||||
num_total_modes,
|
||||
|
||||
@ -41,6 +41,7 @@
|
||||
#include <cute/algorithm/prefetch.hpp>
|
||||
|
||||
#include <cute/numeric/integral_ratio.hpp>
|
||||
#include <cutlass/cuda_host_adapter.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@ -983,7 +984,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The origin
|
||||
|
||||
// TMA smem swizzle type
|
||||
CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle));
|
||||
CUresult result = cuTensorMapEncodeTiled(
|
||||
CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
|
||||
&tma_desc,
|
||||
tma_format,
|
||||
tma_dim,
|
||||
|
||||
@ -82,6 +82,78 @@ namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
|
||||
#include <cudaTypedefs.h>
|
||||
#include <driver_types.h>
|
||||
|
||||
#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok
|
||||
|
||||
#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
|
||||
|
||||
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
|
||||
template <typename... Args> \
|
||||
CUresult call_##func(Args... args) { \
|
||||
return func(args...); \
|
||||
}
|
||||
|
||||
#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5)
|
||||
|
||||
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
|
||||
template <typename... Args> \
|
||||
CUresult call_##func(Args... args) { \
|
||||
cudaDriverEntryPointQueryResult cuda_status; \
|
||||
void* pfn = nullptr; \
|
||||
cudaError_t cuda_err = cudaGetDriverEntryPointByVersion( \
|
||||
CUTLASS_CUDA_DRIVER_STRINGIFY(func), \
|
||||
&pfn, ver, \
|
||||
cudaEnableDefault, \
|
||||
&cuda_status); \
|
||||
if (cuda_status != cudaDriverEntryPointSuccess || \
|
||||
cuda_err != cudaSuccess) { \
|
||||
return CUDA_ERROR_UNKNOWN; \
|
||||
} \
|
||||
return reinterpret_cast<PFN_##func##_v##ver>(pfn)(args...); \
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
|
||||
template <typename... Args> \
|
||||
CUresult call_##func(Args... args) { \
|
||||
cudaDriverEntryPointQueryResult cuda_status; \
|
||||
void* pfn = nullptr; \
|
||||
cudaError_t cuda_err = cudaGetDriverEntryPoint( \
|
||||
CUTLASS_CUDA_DRIVER_STRINGIFY(func), \
|
||||
&pfn, \
|
||||
cudaEnableDefault, \
|
||||
&cuda_status); \
|
||||
if (cuda_status != cudaDriverEntryPointSuccess || \
|
||||
cuda_err != cudaSuccess) { \
|
||||
return CUDA_ERROR_UNKNOWN; \
|
||||
} \
|
||||
return reinterpret_cast<PFN_##func>(pfn)(args...); \
|
||||
}
|
||||
|
||||
#endif // (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5)
|
||||
|
||||
#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000);
|
||||
CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000);
|
||||
#endif
|
||||
|
||||
#undef CUTLASS_CUDA_DRIVER_STRINGIFY
|
||||
|
||||
#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func
|
||||
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter
|
||||
/// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute
|
||||
/// is not introduced.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user