[Build] Guard against older CUDA versions when building CUTLASS 3.x kernels (#5168)
This commit is contained in:
parent
657579113f
commit
1197e02141
@ -1,3 +1,9 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
@ -6,8 +12,6 @@
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
@ -241,3 +245,5 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
if (version_num >= 90) {
|
||||
// Hopper
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
|
||||
#else
|
||||
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
|
||||
#endif
|
||||
} else if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user