[Hardware][NV] Add support for ModelOpt static scaling checkpoints. (#6112)

This commit is contained in:
Pavani Majety 2024-09-10 21:38:40 -07:00 committed by GitHub
parent 1230263e16
commit efcf946a15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 258 additions and 6 deletions

View File

@ -1,6 +1,6 @@
### Quantizer Utilities
`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM:
`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py`
`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported
from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py)
### Prerequisite

View File

@ -0,0 +1,79 @@
# flake8: noqa
"""Tests Model Optimizer fp8 models against ground truth generation
Note: these tests will only pass on H100
"""
import os
from typing import List
import pytest
from transformers import AutoTokenizer
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024
MODELS = ["nvidia/Llama-3.1-8B-Instruct-FP8"]
EXPECTED_STRS_MAP = {
"nvidia/Llama-3.1-8B-Instruct-FP8": [
"You're referring to VLLM, a high-performance Large Language Model (LLM) inference and",
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, whir',
'The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる'
]
}
# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against
# and is unstable w.r.t specifics of the fp8 implementation or
# the hardware being run on.
# Disabled to prevent it from breaking the build
@pytest.mark.skip(
reason=
"Prevent unstable test based on golden strings from breaking the build.")
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
model = LLM(
model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True,
quantization="modelopt",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [
tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
tokenize=False,
add_generation_prompt=True)
for prompt in example_prompts
]
params = SamplingParams(max_tokens=20, temperature=0)
generations: List[str] = []
# Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way.
for prompt in formatted_prompts:
outputs = model.generate(prompt, params)
generations.append(outputs[0].outputs[0].text)
del model
print(model_name, generations)
expected_strs = EXPECTED_STRS_MAP[model_name]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]
assert expected_str == generated_str, (
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")

View File

@ -282,9 +282,9 @@ class ModelConfig:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["awq", "gptq", "fp8"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors",
"experts_int8"
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8"
]
tpu_supported_quantization = ["tpu_int8"]
neuron_supported_quantization = ["neuron_quant"]

View File

@ -26,7 +26,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod"
]

View File

@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
from vllm.model_executor.layers.quantization.qqq import QQQConfig
@ -34,6 +35,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,

View File

@ -0,0 +1,163 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
logger = init_logger(__name__)
ACTIVATION_SCHEMES = ["static"]
class ModelOptFp8Config(QuantizationConfig):
"""Config class for ModelOpt FP8."""
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change.")
@classmethod
def get_name(cls) -> str:
return "modelopt"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
if not is_checkpoint_fp8_serialized:
raise ValueError("ModelOpt currently only supports static FP8"
"quantization in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
return cls(is_checkpoint_fp8_serialized)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: ModelOptFp8Config):
super().__init__(quant_config)
class ModelOptFp8LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
activation scale. Future support might be added for dynamic
scales.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn datatype
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None:
max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)

View File

@ -192,6 +192,13 @@ def get_quant_config(model_config: ModelConfig,
if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path
elif model_config.quantization == "modelopt":
if config["producer"]["name"] == "modelopt":
return quant_cls.from_config(config)
else:
raise ValueError(
f"Unsupported quantization config"
f" found for {model_config.quantization} in {f}.")
return quant_cls.from_config(config)