[AMD][ROCm]Quantization methods on ROCm; Fix _scaled_mm call (#8380)
Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
parent e18749ff09
commit b3195bc9e4
@@ -255,7 +255,10 @@ class ModelConfig:
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["awq", "gptq", "fp8"]
+        rocm_supported_quantization = [
+            "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
+            "fbgemm_fp8"
+        ]
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
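Note: this list only declares which quantization methods are allowed on ROCm; the check that consumes it lies outside the hunk. A minimal sketch of how such an allowlist is typically enforced (hypothetical helper, not the verbatim vLLM logic; is_hip comes from vllm.utils as in the hunks below):

from vllm.utils import is_hip

def check_rocm_quantization(quantization: str,
                            rocm_supported_quantization: list) -> None:
    # Reject quantization methods that have no ROCm kernels; illustrative only.
    if is_hip() and quantization not in rocm_supported_quantization:
        raise ValueError(
            f"{quantization} quantization is currently not supported on ROCm.")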
@@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
+    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
+    requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
+from vllm.utils import is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
@@ -39,16 +41,37 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                 logical_widths=layer.logical_widths,
             )
 
+            if is_hip():
+                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=max_w_scale,
+                    input_scale=layer.input_scale)
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale,
+                                                  requires_grad=False)
+
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
         # If channelwise, scales are already lined up, so just transpose.
         elif self.strategy == QuantizationStrategy.CHANNEL:
             weight = layer.weight
+
+            if is_hip():
+                weight, weight_scale, input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=layer.weight_scale,
+                        input_scale=layer.input_scale)
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale,
+                                                  requires_grad=False)
+            else:
+                weight_scale = layer.weight_scale.data
+
             layer.weight = Parameter(weight.t(), requires_grad=False)
             # required by torch.compile to be torch.nn.Parameter
-            layer.weight_scale = Parameter(layer.weight_scale.data,
-                                           requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
         else:
             raise ValueError(f"Unknown quantization strategy {self.strategy}")
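normalize_e4m3fn_to_e4m3fnuz itself is not part of this diff. A sketch of the conversion it is expected to perform, assuming the usual e4m3fn-to-e4m3fnuz handling (reinterpret the bits, then compensate the scales); the body below is an illustration, not the committed implementation:

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Illustrative sketch only.
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # The same bit pattern decodes to half the value in e4m3fnuz,
    # so the scales are doubled to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale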
@@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -125,8 +126,18 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
 
+        if is_hip():
+            weight, weight_scale, input_scale = \
+                normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=None)
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        layer.weight = Parameter(weight.t(), requires_grad=False)
         if self.quant_config.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
             # Activations not quantized for marlin.
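The reason the ROCm path has to renormalize at all: CUDA fp8 checkpoints use float8_e4m3fn, while MI300-class ROCm hardware uses float8_e4m3fnuz, which has a different exponent bias and therefore a smaller dynamic range. A quick plain-PyTorch check of the two formats:

import torch

# OCP e4m3fn (CUDA) vs e4m3fnuz (ROCm): the exponent bias differs by one,
# so a bit pattern that is finite in both formats decodes to half the value
# in e4m3fnuz, and the representable range is roughly halved.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0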
@@ -6,11 +6,9 @@ from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
 from vllm.utils import is_hip
 
-# scaled_mm in pytorch on rocm has a bug that requires always
-# providing scaling factor for result. This value is created
-# as global value to avoid multiple tensor allocations, and
-# can be removed once pytorch fixes the bug.
-TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
+# Input scaling factors are no longer optional in _scaled_mm starting
+# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
 
 
 def cutlass_fp8_supported() -> bool:
@@ -131,19 +129,17 @@ def apply_fp8_linear(
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output = torch._scaled_mm(
-            qinput,
-            weight,
-            out_dtype=input.dtype,
-            scale_a=x_scale,
-            scale_b=weight_scale,
-            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
-            bias=bias)
-        # Since in torch 2.5, scaled_mm only returns single value
-        # This should be removed when vllm-nvidia also moves to 2.5
-        if is_hip():
-            return torch.narrow(output, 0, 0, input.shape[0])
-        return torch.narrow(output[0], 0, 0, input.shape[0])
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=input.dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale,
+                                  bias=bias)
+        # A fix for discrepancy in scaled_mm which returns tuple
+        # for torch < 2.5 and a single value in torch >= 2.5
+        if type(output) is tuple and len(output) == 2:
+            return torch.narrow(output[0], 0, 0, input.shape[0])
+        return torch.narrow(output, 0, 0, input.shape[0])
 
     else:
         # Fallback for channelwise case, where we use unfused DQ
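The tuple-vs-tensor guard above is repeated in the fallback path below; if it keeps spreading, it could be factored into a small compatibility wrapper along these lines (hypothetical helper, not part of this commit):

import torch


def scaled_mm_compat(a, b, scale_a, scale_b, out_dtype, bias=None):
    # torch._scaled_mm returns a tuple (result first) on torch < 2.5 and a
    # single tensor on torch >= 2.5; normalize to a single tensor here.
    out = torch._scaled_mm(a,
                           b,
                           scale_a=scale_a,
                           scale_b=scale_b,
                           out_dtype=out_dtype,
                           bias=bias)
    return out[0] if isinstance(out, tuple) else out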
@@ -161,12 +157,23 @@ def apply_fp8_linear(
         # For the scaled_mm fallback case, we break this down, since it
         # does not support s_w being a vector.
 
+        # Making sure the dummy tensor is on the same device as the weight
+        global TORCH_DEVICE_IDENTITY
+        if TORCH_DEVICE_IDENTITY.device != weight.device:
+            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
         # GEMM
         # This computes C = (X * W).
         # Output in fp32 to allow subsequent ops to happen in-place
-        output, _ = torch._scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=torch.float32)
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  scale_a=TORCH_DEVICE_IDENTITY,
+                                  scale_b=TORCH_DEVICE_IDENTITY,
+                                  out_dtype=torch.float32)
+        # A fix for discrepancy in scaled_mm which returns tuple
+        # for torch < 2.5 and a single value in torch >= 2.5
+        if type(output) is tuple and len(output) == 2:
+            output = output[0]
         # Unpad (undo num_token_padding)
         output = torch.narrow(output, 0, 0, input.shape[0])
         x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
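For reference, passing identity scales and rescaling afterwards works because per-token and per-channel dequantization commute with the GEMM: with activation scales s_x (one per row) and weight scales s_w (one per output channel), (X_q * s_x) @ (W_q * s_w) == ((X_q @ W_q) * s_x) * s_w. A self-contained check with ordinary fp32 tensors standing in for the quantized operands:

import torch

M, K, N = 4, 8, 3
Xq = torch.randn(M, K)   # stand-in for the dequantized activations
Wq = torch.randn(K, N)   # stand-in for the weight matrix
s_x = torch.rand(M, 1)   # per-token activation scales
s_w = torch.rand(1, N)   # per-output-channel weight scales

fused = (Xq * s_x) @ (Wq * s_w)    # scale operands first, then GEMM
unfused = (Xq @ Wq) * s_x * s_w    # GEMM with identity scales, rescale after
assert torch.allclose(fused, unfused, atol=1e-5)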