[4/N] make quant config first-class citizen (#9978)

Signed-off-by: youkaichao <youkaichao@gmail.com>

commit 8d72bb20fa (parent: ac6b8f19b9)
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -23,9 +23,13 @@ if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
     from vllm.executor.executor_base import ExecutorBase
+    from vllm.model_executor.layers.quantization.base_config import (
+        QuantizationConfig)
     from vllm.model_executor.model_loader.loader import BaseModelLoader
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)
+else:
+    QuantizationConfig = None
 
 logger = init_logger(__name__)
 
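The TYPE_CHECKING guard above keeps the quantization import out of the runtime import path while still letting annotations name QuantizationConfig. A minimal, self-contained sketch of the same pattern (standalone; not vllm's actual module layout):

    from dataclasses import dataclass
    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Seen only by static type checkers; never executed at runtime.
        from vllm.model_executor.layers.quantization.base_config import (
            QuantizationConfig)
    else:
        # The dataclass annotation below is evaluated when the class body
        # runs, so the name must be bound; Optional[None] is valid typing.
        QuantizationConfig = None

    @dataclass
    class Example:
        quant_config: Optional[QuantizationConfig] = None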
@@ -1966,6 +1970,35 @@ class VllmConfig:
     decoding_config: Optional[DecodingConfig] = None
     observability_config: Optional[ObservabilityConfig] = None
     prompt_adapter_config: Optional[PromptAdapterConfig] = None
+    quant_config: Optional[QuantizationConfig] = None
+
+    @staticmethod
+    def _get_quantization_config(
+            model_config: ModelConfig,
+            load_config: LoadConfig) -> Optional[QuantizationConfig]:
+        """Get the quantization config."""
+        if model_config.quantization is not None:
+            from vllm.model_executor.model_loader.weight_utils import (
+                get_quant_config)
+            quant_config = get_quant_config(model_config, load_config)
+            capability_tuple = current_platform.get_device_capability()
+
+            if capability_tuple is not None:
+                capability = capability_tuple.to_int()
+                if capability < quant_config.get_min_capability():
+                    raise ValueError(
+                        f"The quantization method {model_config.quantization} "
+                        "is not supported for the current GPU. Minimum "
+                        f"capability: {quant_config.get_min_capability()}. "
+                        f"Current capability: {capability}.")
+            supported_dtypes = quant_config.get_supported_act_dtypes()
+            if model_config.dtype not in supported_dtypes:
+                raise ValueError(
+                    f"{model_config.dtype} is not supported for quantization "
+                    f"method {model_config.quantization}. Supported dtypes: "
+                    f"{supported_dtypes}")
+            return quant_config
+        return None
+
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
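In isolation, the new staticmethod applies two gates before returning a config: a compute-capability check (to_int() encodes the capability as major * 10 + minor, so sm_80 becomes 80) and an activation-dtype check. A standalone sketch with hypothetical threshold and dtype values (the real ones come from the QuantizationConfig subclass via get_min_capability() and get_supported_act_dtypes()):

    import torch

    # Hypothetical values for illustration only.
    MIN_CAPABILITY = 75                      # e.g. requires sm_75 or newer
    SUPPORTED_DTYPES = [torch.half, torch.bfloat16]

    def validate(capability: int, dtype: torch.dtype) -> None:
        # Gate 1: the GPU must meet the method's minimum compute capability.
        if capability < MIN_CAPABILITY:
            raise ValueError(f"Minimum capability: {MIN_CAPABILITY}. "
                             f"Current capability: {capability}.")
        # Gate 2: the model dtype must be one the method can act on.
        if dtype not in SUPPORTED_DTYPES:
            raise ValueError(f"{dtype} is not supported. "
                             f"Supported dtypes: {SUPPORTED_DTYPES}")

    validate(80, torch.half)                 # passes: sm_80, fp16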
@@ -1983,3 +2016,8 @@ class VllmConfig:
         if self.prompt_adapter_config:
             self.prompt_adapter_config.verify_with_model_config(
                 self.model_config)
+
+        if self.quant_config is None and \
+            self.model_config is not None and self.load_config is not None:
+            self.quant_config = VllmConfig._get_quantization_config(
+                self.model_config, self.load_config)
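With the field and the __post_init__ hook in place, quant_config resolves once at config-construction time instead of at each load site. A hedged usage sketch (ModelConfig/LoadConfig construction elided; my_quant_config is a hypothetical pre-built QuantizationConfig):

    # Left unset, quant_config is derived from model_config.quantization
    # and validated against the current GPU in __post_init__.
    config = VllmConfig(model_config=model_config, load_config=load_config)

    # As a first-class field it can also be injected directly, in which
    # case the derivation above is skipped.
    config = VllmConfig(model_config=model_config,
                        load_config=load_config,
                        quant_config=my_quant_config)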
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
-    get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
+    get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
     initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
     safetensors_weights_iterator)
 from vllm.model_executor.models import (has_inner_state, supports_lora,
@@ -93,32 +93,6 @@ def device_loading_context(module: torch.nn.Module,
 logger = init_logger(__name__)
 
 
-def _get_quantization_config(
-        model_config: ModelConfig,
-        load_config: LoadConfig) -> Optional[QuantizationConfig]:
-    """Get the quantization config."""
-    if model_config.quantization is not None:
-        quant_config = get_quant_config(model_config, load_config)
-        capability_tuple = current_platform.get_device_capability()
-
-        if capability_tuple is not None:
-            capability = capability_tuple.to_int()
-            if capability < quant_config.get_min_capability():
-                raise ValueError(
-                    f"The quantization method {model_config.quantization} "
-                    "is not supported for the current GPU. "
-                    f"Minimum capability: {quant_config.get_min_capability()}. "
-                    f"Current capability: {capability}.")
-        supported_dtypes = quant_config.get_supported_act_dtypes()
-        if model_config.dtype not in supported_dtypes:
-            raise ValueError(
-                f"{model_config.dtype} is not supported for quantization "
-                f"method {model_config.quantization}. Supported dtypes: "
-                f"{supported_dtypes}")
-        return quant_config
-    return None
-
-
 def _get_model_initialization_kwargs(
     model_class: Type[nn.Module],
     lora_config: Optional[LoRAConfig],
@@ -185,7 +159,6 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
     lora_config = vllm_config.lora_config
     scheduler_config = vllm_config.scheduler_config
     cache_config = vllm_config.cache_config
-    load_config = vllm_config.load_config
     model_class, _ = get_model_architecture(model_config)
 
     return build_model(
@@ -193,7 +166,7 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
         vllm_config,
         model_config.hf_config,
         cache_config=cache_config,
-        quant_config=_get_quantization_config(model_config, load_config),
+        quant_config=vllm_config.quant_config,
         lora_config=lora_config,
         multimodal_config=model_config.multimodal_config,
         scheduler_config=scheduler_config,
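The call-site change in miniature: loaders stop recomputing the quantization config and instead read the single instance resolved on VllmConfig, so every consumer sees the same object:

    # Before: recomputed per load site from model/load configs.
    quant_config = _get_quantization_config(model_config, load_config)

    # After: read once from the config object populated in __post_init__.
    quant_config = vllm_config.quant_config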
@@ -518,8 +491,7 @@ class TensorizerLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model_class = get_model_architecture(model_config)[0]
-                quant_config = _get_quantization_config(
-                    model_config, self.load_config)
+                quant_config = vllm_config.quant_config
                 extra_kwargs = _get_model_initialization_kwargs(
                     model_class, lora_config, model_config.multimodal_config)
                 extra_kwargs["quant_config"] = quant_config