[Kernel] Implement fallback for FP8 channelwise using torch._scaled_mm (#6552)
This commit is contained in:
parent
f53b8f0d05
commit
4ffffccb7e
@ -23,16 +23,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
self.is_static_input_scheme = is_static_input_scheme
|
self.is_static_input_scheme = is_static_input_scheme
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
|
|
||||||
# On Lovelace, fail for now if channelwise.
|
|
||||||
# TODO: (@tms) fallback
|
|
||||||
if (not self.cutlass_fp8_supported
|
|
||||||
and self.strategy == QuantizationStrategy.CHANNEL):
|
|
||||||
raise ValueError(
|
|
||||||
"Channelwise fp8 quantization requires vLLM's custom "
|
|
||||||
"cutlass kernels, which are not supported on your device."
|
|
||||||
"Consider quantizing with per tensor scales or upgrading "
|
|
||||||
"to Hopper.")
|
|
||||||
|
|
||||||
def get_min_capability(self) -> int:
|
def get_min_capability(self) -> int:
|
||||||
# lovelace and up
|
# lovelace and up
|
||||||
return 89
|
return 89
|
||||||
@ -53,7 +43,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
|
|
||||||
# If channelwise, scales are already lined up, so just transpose.
|
# If channelwise, scales are already lined up, so just transpose.
|
||||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||||
assert self.cutlass_fp8_supported
|
|
||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||||
|
|
||||||
|
|||||||
@ -124,20 +124,50 @@ def apply_fp8_linear(
|
|||||||
bias=bias)
|
bias=bias)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
# Note: we pad the input because torch._scaled_mm is more performant
|
||||||
|
# for matrices with batch dimension > 16.
|
||||||
|
# This could change in the future.
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(input,
|
qinput, x_scale = ops.scaled_fp8_quant(input,
|
||||||
input_scale,
|
input_scale,
|
||||||
batch_dim_padding=17)
|
batch_dim_padding=17)
|
||||||
|
|
||||||
# Fused GEMM_DQ -- note we padded the input above because
|
if weight_scale.numel() == 1:
|
||||||
# torch._scaled_mm is more performant for matrices with
|
# Fused GEMM_DQ
|
||||||
# batch dimension > 16. Note that this could change
|
output, _ = torch._scaled_mm(qinput,
|
||||||
# in the future.
|
weight,
|
||||||
output, _ = torch._scaled_mm(qinput,
|
out_dtype=input.dtype,
|
||||||
weight,
|
scale_a=x_scale,
|
||||||
out_dtype=input.dtype,
|
scale_b=weight_scale,
|
||||||
scale_a=x_scale,
|
bias=bias)
|
||||||
scale_b=weight_scale,
|
else:
|
||||||
bias=bias)
|
# Fallback for channelwise case, where the weight scales are
|
||||||
|
# applied separately.
|
||||||
|
|
||||||
|
# Symmetric quantized GEMM by definition computes the following:
|
||||||
|
# C = (s_x * X) (s_w * W) + bias
|
||||||
|
# This is equivalent to dequantizing the weights and activations
|
||||||
|
# before applying a GEMM.
|
||||||
|
#
|
||||||
|
# In order to compute on quantized operands, a quantized kernel
|
||||||
|
# will rewrite the above like so:
|
||||||
|
# C = s_w * s_x * (X * W) + bias
|
||||||
|
#
|
||||||
|
# For the scaled_mm fallback case, we break this down, since it
|
||||||
|
# does not support s_w being a vector.
|
||||||
|
|
||||||
|
# This computes C = s_x * (X * W).
|
||||||
|
# Output in fp32 to allow subsequent ops to happen in-place
|
||||||
|
output, _ = torch._scaled_mm(qinput,
|
||||||
|
weight,
|
||||||
|
out_dtype=torch.float32,
|
||||||
|
scale_a=x_scale)
|
||||||
|
|
||||||
|
# C = s_w * s_x * (X * W)
|
||||||
|
output = output * weight_scale.t()
|
||||||
|
if bias is not None:
|
||||||
|
# C = s_w * s_x * (X * W) + bias
|
||||||
|
output = output + bias
|
||||||
|
output = output.to(dtype=input.dtype)
|
||||||
|
|
||||||
return torch.narrow(output, 0, 0, input.shape[0])
|
return torch.narrow(output, 0, 0, input.shape[0])
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user