From bb00f66e19acdf6cb614683ab74f777ed3932eee Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 17 Nov 2023 16:23:49 -0800 Subject: [PATCH] Use `quantization_config` in hf config (#1695) --- vllm/config.py | 32 +++++++++++++++++++++-------- vllm/model_executor/model_loader.py | 1 + vllm/model_executor/weight_utils.py | 11 ++++++++-- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index fda9a268..955f091d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -104,14 +104,30 @@ class ModelConfig: def _verify_quantization(self) -> None: supported_quantization = ["awq", "squeezellm"] - if self.quantization is None: - return - quantization = self.quantization.lower() - if quantization not in supported_quantization: - raise ValueError( - f"Unknown quantization: {self.quantization}. Must be one of " - f"{supported_quantization}.") - self.quantization = quantization + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. + hf_quant_config = getattr(self.hf_config, "quantization_config", None) + if hf_quant_config is not None: + hf_quant_method = str(hf_quant_config["quant_method"]).lower() + if self.quantization is None: + self.quantization = hf_quant_method + elif self.quantization != hf_quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({hf_quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + logger.warning(f"{self.quantization} quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.") def verify_with_parallel_config( self, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index aad369b1..71a22c77 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -66,6 +66,7 @@ def get_model(model_config: ModelConfig) -> nn.Module: if model_config.quantization is not None: quant_config = get_quant_config(model_config.quantization, model_config.model, + model_config.hf_config, model_config.download_dir) capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index a91eaab5..eec43e29 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -7,9 +7,10 @@ from collections import defaultdict from typing import Any, Iterator, List, Optional, Tuple from huggingface_hub import snapshot_download -from safetensors.torch import load_file, save_file, safe_open import numpy as np +from safetensors.torch import load_file, save_file, safe_open import torch +from transformers import PretrainedConfig from tqdm.auto import tqdm from vllm.logger import init_logger @@ -84,8 +85,15 @@ def convert_bin_to_safetensor_file( def get_quant_config( quantization: str, model_name_or_path: str, + hf_config: PretrainedConfig, cache_dir: Optional[str] = None, ) -> QuantizationConfig: + quant_cls = get_quantization_config(quantization) + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(hf_config, "quantization_config", None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. @@ -98,7 +106,6 @@ def get_quant_config( hf_folder = model_name_or_path config_files = glob.glob(os.path.join(hf_folder, "*.json")) - quant_cls = get_quantization_config(quantization) quant_config_files = [ f for f in config_files if any( f.endswith(x) for x in quant_cls.get_config_filenames())