From e69ded7d1c8a4f6ed26e64090bdc050c06cde3b9 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Fri, 7 Jun 2024 17:42:05 -0700 Subject: [PATCH] [Bug Fix] Fix the support check for FP8 CUTLASS (#5352) Bug description: With torch 2.4.0.dev20240603+cu121, cutlass_fp8_supported outputs False, and the (capability, version) before the comparison is (90, 11111111112) This PR fixes the support check for FP8 CUTLASS ( cutlass_fp8_supported) which was introduced in https://github.com/vllm-project/vllm/pull/5183. --- vllm/model_executor/layers/quantization/fp8.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 136a6462..de94bad7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -20,16 +20,16 @@ logger = init_logger(__name__) def cutlass_fp8_supported() -> bool: capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] - version = torch.version.cuda - version = version[0] * 10 + version[1] + major, minor = torch.version.cuda.split(".") + version = int(major) * 10 + int(minor) # CUTLASS FP8 kernels need at least # CUDA 12.0 on SM90 systems (Hopper) # CUDA 12.4 on SM89 systems (Lovelace) gpu_is_supported = False - if capability >= 900: + if capability >= 90: gpu_is_supported = version > 120 - elif capability >= 890: + elif capability >= 89: gpu_is_supported = version > 124 return gpu_is_supported @@ -103,7 +103,7 @@ class Fp8LinearMethod(LinearMethodBase): 1. Only support per-tensor quantization due to torch._scaled_mm support. 2. Only support float8_e4m3fn data type due to the limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) - + Args: quant_config: The quantization config. """ @@ -298,8 +298,8 @@ class Fp8KVCacheMethod(QuantizeMethodBase): self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module): - """Create "weight" (aka kv_scale) for an attention layer. - + """Create "weight" (aka kv_scale) for an attention layer. + Args: layer: The layer that is using the QuantizeMethodBase factory. """