def cutlass_fp8_supported(*,
                          capability: Optional[tuple] = None,
                          cuda_version: Optional[str] = None) -> bool:
    """Return True when the CUTLASS FP8 GEMM kernels can be used.

    CUTLASS FP8 kernels need at least
      * CUDA 12.0 on SM90 systems (Hopper)
      * CUDA 12.4 on SM89 systems (Lovelace)

    Args:
        capability: Optional ``(major, minor)`` compute capability. When
            omitted, the current CUDA device is queried through
            ``torch.cuda.get_device_capability()``.
        cuda_version: Optional CUDA version string such as ``"12.1"``. When
            omitted, ``torch.version.cuda`` is used.

    Returns:
        True if the device / CUDA combination meets the requirements above,
        False otherwise (including when no CUDA build or device is present).
    """
    if cuda_version is None:
        cuda_version = torch.version.cuda
    if cuda_version is None:
        # torch.version.cuda is None when PyTorch was not built with CUDA.
        return False

    if capability is None:
        if not torch.cuda.is_available():
            return False
        capability = torch.cuda.get_device_capability()

    # torch.version.cuda is a string ("12.1"); parse it numerically rather
    # than indexing characters, and compare as a tuple so that multi-digit
    # components are ordered correctly.
    version_parts = cuda_version.split(".")
    cuda_major = int(version_parts[0])
    cuda_minor = int(version_parts[1]) if len(version_parts) > 1 else 0
    cuda = (cuda_major, cuda_minor)

    # Encode the capability as major * 10 + minor, e.g. (9, 0) -> 90 and
    # (8, 9) -> 89, and compare against thresholds on the same scale.
    sm = capability[0] * 10 + capability[1]

    # ">=" because the stated minimum versions are themselves supported.
    if sm >= 90:
        return cuda >= (12, 0)
    if sm >= 89:
        return cuda >= (12, 4)
    return False
+252,40 @@ class Fp8LinearMethod(LinearMethodBase): layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.act_scale is None and x_scale computed from x. # If static, layer.act_scale is scalar and x_scale set to act_scale. - qinput, x_scale = ops.scaled_fp8_quant(x, - layer.act_scale, - batch_dim_padding=17) - # Fused GEMM_DQ -- note we padded the input above because - # torch._scaled_mm is more performant for matrices with - # batch dimension > 16. Note that this could change - # in the future. - output, _ = torch._scaled_mm( - qinput, - layer.weight, - out_dtype=x.dtype, - scale_a=x_scale, - scale_b=layer.weight_scale, - bias=bias, - ) + if bias is None and self.cutlass_fp8_supported: + qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm_dq( + qinput, + layer.weight, + out_dtype=x.dtype, + scale_a=x_scale, + scale_b=layer.weight_scale, + ) + + else: + qinput, x_scale = ops.scaled_fp8_quant(x, + layer.act_scale, + batch_dim_padding=17) + + # Fused GEMM_DQ -- note we padded the input above because + # torch._scaled_mm is more performant for matrices with + # batch dimension > 16. Note that this could change + # in the future. + output, _ = torch._scaled_mm( + qinput, + layer.weight, + out_dtype=x.dtype, + scale_a=x_scale, + scale_b=layer.weight_scale, + bias=bias, + ) return torch.narrow(output, 0, 0, x.shape[0])