Use quantization_config in hf config (#1695)
This commit is contained in:
parent
e87557b069
commit
bb00f66e19
@ -104,14 +104,30 @@ class ModelConfig:
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq", "squeezellm"]
|
||||
if self.quantization is None:
|
||||
return
|
||||
quantization = self.quantization.lower()
|
||||
if quantization not in supported_quantization:
|
||||
raise ValueError(
|
||||
f"Unknown quantization: {self.quantization}. Must be one of "
|
||||
f"{supported_quantization}.")
|
||||
self.quantization = quantization
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
# Parse quantization method from the HF model config, if available.
|
||||
hf_quant_config = getattr(self.hf_config, "quantization_config", None)
|
||||
if hf_quant_config is not None:
|
||||
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
|
||||
if self.quantization is None:
|
||||
self.quantization = hf_quant_method
|
||||
elif self.quantization != hf_quant_method:
|
||||
raise ValueError(
|
||||
"Quantization method specified in the model config "
|
||||
f"({hf_quant_method}) does not match the quantization "
|
||||
f"method specified in the `quantization` argument "
|
||||
f"({self.quantization}).")
|
||||
|
||||
if self.quantization is not None:
|
||||
if self.quantization not in supported_quantization:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
logger.warning(f"{self.quantization} quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.")
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
|
||||
@ -66,6 +66,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
if model_config.quantization is not None:
|
||||
quant_config = get_quant_config(model_config.quantization,
|
||||
model_config.model,
|
||||
model_config.hf_config,
|
||||
model_config.download_dir)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
@ -7,9 +7,10 @@ from collections import defaultdict
|
||||
from typing import Any, Iterator, List, Optional, Tuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@ -84,8 +85,15 @@ def convert_bin_to_safetensor_file(
|
||||
def get_quant_config(
|
||||
quantization: str,
|
||||
model_name_or_path: str,
|
||||
hf_config: PretrainedConfig,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> QuantizationConfig:
|
||||
quant_cls = get_quantization_config(quantization)
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(hf_config, "quantization_config", None)
|
||||
if hf_quant_config is not None:
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
@ -98,7 +106,6 @@ def get_quant_config(
|
||||
hf_folder = model_name_or_path
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_cls = get_quantization_config(quantization)
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(
|
||||
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user