Use CUDA runtime API to retrieve function pointer to driver API (#1700)
* Query pfn to driver api * use default for older toolkits --------- Co-authored-by: shunfans <shunfans@nvidia.com>
This commit is contained in:
parent
f93a69134e
commit
4dbf5dbed2
@ -234,6 +234,7 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries
|
|||||||
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||||
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
|
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
|
||||||
|
set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API)
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
#
|
#
|
||||||
|
|||||||
@ -40,6 +40,8 @@
|
|||||||
|
|
||||||
#include "cute/algorithm/prefetch.hpp"
|
#include "cute/algorithm/prefetch.hpp"
|
||||||
#include "cutlass/fast_math.h"
|
#include "cutlass/fast_math.h"
|
||||||
|
#include "cutlass/cuda_host_adapter.hpp"
|
||||||
|
|
||||||
namespace cute
|
namespace cute
|
||||||
{
|
{
|
||||||
|
|
||||||
@ -450,7 +452,7 @@ make_im2col_tma_copy_desc(
|
|||||||
CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_);
|
CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_);
|
||||||
CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle));
|
CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle));
|
||||||
|
|
||||||
CUresult encode_result = cuTensorMapEncodeIm2col(
|
CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
|
||||||
&tma_desc,
|
&tma_desc,
|
||||||
tma_format,
|
tma_format,
|
||||||
num_total_modes,
|
num_total_modes,
|
||||||
|
|||||||
@ -41,6 +41,7 @@
|
|||||||
#include <cute/algorithm/prefetch.hpp>
|
#include <cute/algorithm/prefetch.hpp>
|
||||||
|
|
||||||
#include <cute/numeric/integral_ratio.hpp>
|
#include <cute/numeric/integral_ratio.hpp>
|
||||||
|
#include <cutlass/cuda_host_adapter.hpp>
|
||||||
|
|
||||||
namespace cute
|
namespace cute
|
||||||
{
|
{
|
||||||
@ -983,7 +984,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The origin
|
|||||||
|
|
||||||
// TMA smem swizzle type
|
// TMA smem swizzle type
|
||||||
CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle));
|
CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle));
|
||||||
CUresult result = cuTensorMapEncodeTiled(
|
CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
|
||||||
&tma_desc,
|
&tma_desc,
|
||||||
tma_format,
|
tma_format,
|
||||||
tma_dim,
|
tma_dim,
|
||||||
|
|||||||
@ -82,6 +82,78 @@ namespace cutlass {
|
|||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#if !defined(__CUDACC_RTC__)
|
||||||
|
|
||||||
|
#include <cudaTypedefs.h>
|
||||||
|
#include <driver_types.h>
|
||||||
|
|
||||||
|
#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok
|
||||||
|
|
||||||
|
#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
|
||||||
|
|
||||||
|
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
|
||||||
|
template <typename... Args> \
|
||||||
|
CUresult call_##func(Args... args) { \
|
||||||
|
return func(args...); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
|
||||||
|
|
||||||
|
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5)
|
||||||
|
|
||||||
|
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
|
||||||
|
template <typename... Args> \
|
||||||
|
CUresult call_##func(Args... args) { \
|
||||||
|
cudaDriverEntryPointQueryResult cuda_status; \
|
||||||
|
void* pfn = nullptr; \
|
||||||
|
cudaError_t cuda_err = cudaGetDriverEntryPointByVersion( \
|
||||||
|
CUTLASS_CUDA_DRIVER_STRINGIFY(func), \
|
||||||
|
&pfn, ver, \
|
||||||
|
cudaEnableDefault, \
|
||||||
|
&cuda_status); \
|
||||||
|
if (cuda_status != cudaDriverEntryPointSuccess || \
|
||||||
|
cuda_err != cudaSuccess) { \
|
||||||
|
return CUDA_ERROR_UNKNOWN; \
|
||||||
|
} \
|
||||||
|
return reinterpret_cast<PFN_##func##_v##ver>(pfn)(args...); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
|
||||||
|
template <typename... Args> \
|
||||||
|
CUresult call_##func(Args... args) { \
|
||||||
|
cudaDriverEntryPointQueryResult cuda_status; \
|
||||||
|
void* pfn = nullptr; \
|
||||||
|
cudaError_t cuda_err = cudaGetDriverEntryPoint( \
|
||||||
|
CUTLASS_CUDA_DRIVER_STRINGIFY(func), \
|
||||||
|
&pfn, \
|
||||||
|
cudaEnableDefault, \
|
||||||
|
&cuda_status); \
|
||||||
|
if (cuda_status != cudaDriverEntryPointSuccess || \
|
||||||
|
cuda_err != cudaSuccess) { \
|
||||||
|
return CUDA_ERROR_UNKNOWN; \
|
||||||
|
} \
|
||||||
|
return reinterpret_cast<PFN_##func>(pfn)(args...); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5)
|
||||||
|
|
||||||
|
#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)
|
||||||
|
|
||||||
|
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||||
|
CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000);
|
||||||
|
CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#undef CUTLASS_CUDA_DRIVER_STRINGIFY
|
||||||
|
|
||||||
|
#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func
|
||||||
|
|
||||||
|
#endif // !defined(__CUDACC_RTC__)
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter
|
/// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter
|
||||||
/// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute
|
/// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute
|
||||||
/// is not introduced.
|
/// is not introduced.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user