[Kernel] Switch fp8 layers to use the CUTLASS kernels (#5183)
Switch from torch._scaled_mm to vLLM's CUTLASS FP8 kernels when supported; we are seeing a 5-15% improvement in end-to-end performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8. See https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and #5144 for comparisons across different GEMM sizes.
parent 388596c914
commit 8d75fe48ca
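For context on the change below: activations are quantized to FP8 with ops.scaled_fp8_quant and the GEMM + dequantization is then handed to the CUTLASS kernel instead of torch._scaled_mm. A minimal standalone sketch of that path (the import path, tensor layout, and scale shapes here are illustrative assumptions, not part of this commit):

import torch
from vllm import _custom_ops as ops  # assumed import path for the "ops" module used in the diff

M, K, N = 32, 4096, 4096  # K and N are multiples of 16, matching the assert on b below

# Dynamic per-tensor FP8 quantization of the activations (scale=None -> computed from x).
x = torch.randn(M, K, device="cuda", dtype=torch.float16)
qinput, x_scale = ops.scaled_fp8_quant(x, None)

# FP8 weight stored as a transposed (K, N) view, plus an (assumed per-tensor) scale.
weight = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn).t()
weight_scale = torch.ones(1, device="cuda", dtype=torch.float32)

# Fused GEMM + dequant through the CUTLASS kernel -- the fast path added in this commit.
out = ops.cutlass_scaled_mm_dq(qinput, weight, x_scale, weight_scale, out_dtype=torch.float16)
print(out.shape)  # torch.Size([32, 4096])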
@@ -179,7 +179,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 # cutlass
 def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
-                         a_scales: torch.Tensor, b_scales: torch.Tensor,
+                         scale_a: torch.Tensor, scale_b: torch.Tensor,
                          out_dtype: Type[torch.dtype]) -> torch.Tensor:
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
@@ -188,7 +188,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
     n = b.shape[1]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
 
-    vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
+    vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
 
     return out
 
@@ -17,6 +17,24 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 logger = init_logger(__name__)
 
 
+def cutlass_fp8_supported() -> bool:
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    version = torch.version.cuda
+    version = version[0] * 10 + version[1]
+
+    # CUTLASS FP8 kernels need at least
+    # CUDA 12.0 on SM90 systems (Hopper)
+    # CUDA 12.4 on SM89 systems (Lovelace)
+    gpu_is_supported = False
+    if capability >= 900:
+        gpu_is_supported = version > 120
+    elif capability >= 890:
+        gpu_is_supported = version > 124
+
+    return gpu_is_supported
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
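The new helper keys off two runtime facts: torch.cuda.get_device_capability(), which returns a (major, minor) tuple, and torch.version.cuda, which is a version string. A quick standalone check of what a given system reports (illustration only, not part of the commit):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100, (8, 9) on L40S / RTX 4090
    print(f"compute capability: SM{major}{minor}")
    print(f"CUDA version torch was built with: {torch.version.cuda}")  # a string such as "12.1"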
@@ -92,6 +110,7 @@ class Fp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
 
     def _create_scale_param(
         self,
@@ -233,25 +252,40 @@
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         # If dynamic, layer.act_scale is None and x_scale computed from x.
         # If static, layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = ops.scaled_fp8_quant(x,
-                                               layer.act_scale,
-                                               batch_dim_padding=17)
-
-        # Fused GEMM_DQ -- note we padded the input above because
-        # torch._scaled_mm is more performant for matrices with
-        # batch dimension > 16. Note that this could change
-        # in the future.
-        output, _ = torch._scaled_mm(
-            qinput,
-            layer.weight,
-            out_dtype=x.dtype,
-            scale_a=x_scale,
-            scale_b=layer.weight_scale,
-            bias=bias,
-        )
+
+        if bias is None and self.cutlass_fp8_supported:
+            qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+
+            # Fused GEMM_DQ
+            output = ops.cutlass_scaled_mm_dq(
+                qinput,
+                layer.weight,
+                out_dtype=x.dtype,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+            )
+
+        else:
+            qinput, x_scale = ops.scaled_fp8_quant(x,
+                                                   layer.act_scale,
+                                                   batch_dim_padding=17)
+
+            # Fused GEMM_DQ -- note we padded the input above because
+            # torch._scaled_mm is more performant for matrices with
+            # batch dimension > 16. Note that this could change
+            # in the future.
+            output, _ = torch._scaled_mm(
+                qinput,
+                layer.weight,
+                out_dtype=x.dtype,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+                bias=bias,
+            )
 
         return torch.narrow(output, 0, 0, x.shape[0])
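On the fallback path, batch_dim_padding=17 pads the batch (dim 0) of the quantized input, which is why the result is sliced back to x.shape[0] with torch.narrow at the end. A small illustration with plain tensors, assuming the padding simply extends dim 0 to the requested size as the comments in the diff suggest:

import torch

x = torch.randn(5, 8)                        # original batch of 5 rows
padded = torch.zeros(17, 8)                  # stand-in for padding to batch_dim_padding=17
padded[:x.shape[0]] = x
out = padded @ torch.randn(8, 4)             # stand-in for the padded GEMM
out = torch.narrow(out, 0, 0, x.shape[0])    # drop the padded rows again
print(out.shape)                             # torch.Size([5, 4])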