From 3f3b6b21500bce2061cae33706bd47c8b6663771 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Thu, 20 Jun 2024 14:36:10 -0400
Subject: [PATCH] [Bugfix] Fix the CUDA version check for FP8 support in the
 CUTLASS kernels (#5715)

---
 csrc/ops.h                                        |  2 ++
 .../quantization/cutlass_w8a8/scaled_mm_entry.cu  | 16 ++++++++++++++++
 csrc/torch_bindings.cpp                           |  6 ++++++
 vllm/_custom_ops.py                               |  4 ++++
 vllm/model_executor/layers/quantization/fp8.py    | 13 +------------
 5 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/csrc/ops.h b/csrc/ops.h
index ba92cc53..6f0a7143 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -92,6 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight,
                                  torch::Tensor& perm, int64_t size_k,
                                  int64_t size_n, int64_t num_bits);
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales);
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index 687f8efd..f4e582d7 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales);
 #endif
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
+  // CUTLASS FP8 kernels need at least
+  //   CUDA 12.0 on SM90 systems (Hopper)
+  //   CUDA 12.4 on SM89 systems (Lovelace)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
+  } else if (cuda_device_capability >= 89) {
+    return CUDA_VERSION >= 12040;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales) {
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 953f2eb4..227b69d7 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -144,6 +144,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor b, Tensor a_scales,"
       " Tensor b_scales) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
+
+  // Check if cutlass scaled_mm is supported for CUDA devices of the given
+  // capability
+  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
+           &cutlass_scaled_mm_supports_fp8);
 #endif
 
   // Quantized GEMM for GPTQ.
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a053a3aa..e050c117 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -216,6 +216,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # cutlass
+def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
 def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype]) -> torch.Tensor:
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e89fd658..bbf3cde5 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -20,19 +20,8 @@ logger = init_logger(__name__)
 
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    major, minor = torch.version.cuda.split(".")
-    version = int(major) * 10 + int(minor)
-    # CUTLASS FP8 kernels need at least
-    # CUDA 12.0 on SM90 systems (Hopper)
-    # CUDA 12.4 on SM89 systems (Lovelace)
-    gpu_is_supported = False
-    if capability >= 90:
-        gpu_is_supported = version > 120
-    elif capability >= 89:
-        gpu_is_supported = version > 124
-
-    return gpu_is_supported
+    return ops.cutlass_scaled_mm_supports_fp8(capability)
 
 
 class Fp8Config(QuantizationConfig):
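
Note (not part of the patch): a minimal sketch of how the relocated check can be exercised from Python, assuming a vLLM build in which the `_C` extension from this patch is compiled and its ops registered. The helper name `fp8_cutlass_available` and the standalone-script framing are illustrative only, not vLLM API.

# Illustrative sketch only: assumes the _C ops added by this patch are built
# and importable; `fp8_cutlass_available` is a hypothetical helper, not vLLM API.
import torch

from vllm import _custom_ops as ops


def fp8_cutlass_available() -> bool:
    # Same encoding as cutlass_fp8_supported(): (major, minor) -> major*10 + minor,
    # so SM90 (Hopper) -> 90 and SM89 (Lovelace) -> 89.
    major, minor = torch.cuda.get_device_capability()
    # The C++ side compares CUDA_VERSION against 12000 (CUDA 12.0) or 12040 (CUDA 12.4).
    return ops.cutlass_scaled_mm_supports_fp8(major * 10 + minor)


if __name__ == "__main__":
    if torch.cuda.is_available():
        print("CUTLASS FP8 kernels supported:", fp8_cutlass_available())
    else:
        print("No CUDA device found; CUTLASS FP8 kernels unavailable.")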