[Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels (#5715)
commit 3f3b6b2150
parent a7dcc62086
@@ -92,6 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales);
@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales);
 #endif
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
+  // CUTLASS FP8 kernels need at least
+  // CUDA 12.0 on SM90 systems (Hopper)
+  // CUDA 12.4 on SM89 systems (Lovelace)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
+  } else if (cuda_device_capability >= 89) {
+    return CUDA_VERSION >= 12040;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales) {
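The thresholds follow how cuda.h encodes the CUDA_VERSION macro: major * 1000 + minor * 10, so 12000 is CUDA 12.0 and 12040 is CUDA 12.4. Because this compares against the toolkit the kernels were compiled with, it cannot disagree with the built binaries the way a runtime torch.version.cuda check can. A minimal sketch of the encoding:

    def cuda_version_macro(major: int, minor: int) -> int:
        # Mirrors the CUDA_VERSION macro from cuda.h: major * 1000 + minor * 10.
        return major * 1000 + minor * 10

    assert cuda_version_macro(12, 0) == 12000  # SM90 (Hopper) minimum
    assert cuda_version_macro(12, 4) == 12040  # SM89 (Lovelace) minimum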
@@ -144,6 +144,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor b, Tensor a_scales,"
       " Tensor b_scales) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
+
+  // Check if cutlass scaled_mm is supported for CUDA devices of the given
+  // capability
+  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
+           &cutlass_scaled_mm_supports_fp8);
 #endif
 
   // Quantized GEMM for GPTQ.
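Once this registration runs, the C++ predicate is reachable from Python under the torch.ops._C namespace (TORCH_EXTENSION_NAME evidently expands to _C here, matching the wrapper in the next hunk). A quick sanity check, assuming the compiled extension is importable as vllm._C:

    import torch
    import vllm._C  # assumed extension module; importing it runs the TORCH_LIBRARY registration

    # The op takes the packed capability integer (major * 10 + minor).
    print(torch.ops._C.cutlass_scaled_mm_supports_fp8(90))  # SM90 / Hopper
    print(torch.ops._C.cutlass_scaled_mm_supports_fp8(89))  # SM89 / Lovelace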
@@ -216,6 +216,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # cutlass
+def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
 def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype]) -> torch.Tensor:
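A sketch of how a caller might use the new wrapper, assuming it lives in vLLM's _custom_ops module (the file this hunk modifies is not named in the excerpt):

    import torch
    from vllm import _custom_ops as ops  # assumed import path for the wrapper above

    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor  # (9, 0) -> 90, (8, 9) -> 89
    if ops.cutlass_scaled_mm_supports_fp8(capability):
        print("CUTLASS FP8 scaled_mm kernels are usable on this device/build")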
@@ -20,19 +20,8 @@ logger = init_logger(__name__)
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    major, minor = torch.version.cuda.split(".")
-    version = int(major) * 10 + int(minor)
 
-    # CUTLASS FP8 kernels need at least
-    # CUDA 12.0 on SM90 systems (Hopper)
-    # CUDA 12.4 on SM89 systems (Lovelace)
-    gpu_is_supported = False
-    if capability >= 90:
-        gpu_is_supported = version > 120
-    elif capability >= 89:
-        gpu_is_supported = version > 124
-
-    return gpu_is_supported
+    return ops.cutlass_scaled_mm_supports_fp8(capability)
 
 
 class Fp8Config(QuantizationConfig):
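For context, the removed Python check was wrong in two ways: it consulted torch.version.cuda (the toolkit PyTorch reports at runtime) rather than the version the CUTLASS kernels were compiled against, and its strict '>' comparisons on a major * 10 + minor encoding rejected exactly the documented minimum versions. A small demo of the old logic, reproduced here only to show the failure:

    def old_check(cuda_version: str, capability: int) -> bool:
        # Reproduction of the removed logic, for illustration only.
        major, minor = cuda_version.split(".")
        version = int(major) * 10 + int(minor)  # "12.0" -> 120, "12.4" -> 124
        if capability >= 90:
            return version > 120  # strict '>' excludes CUDA 12.0 itself
        elif capability >= 89:
            return version > 124  # strict '>' excludes CUDA 12.4 itself
        return False

    assert old_check("12.0", 90) is False  # wrong: SM90 + CUDA 12.0 is supported
    assert old_check("12.4", 89) is False  # wrong: SM89 + CUDA 12.4 is supported

Delegating to ops.cutlass_scaled_mm_supports_fp8 removes both problems, since the decision is made once, in C++, against the compile-time CUDA_VERSION.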