[Frontend][Core] Override HF config.json via CLI (#5836)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Krishna Mandal 2024-11-09 08:19:27 -08:00 committed by GitHub
parent d88bff1b96
commit b09895a618
7 changed files with 73 additions and 53 deletions
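
Summary of the change: instead of dedicated per-field flags such as --rope-scaling and --rope-theta (now deprecated), any field of the model's HF config.json can be replaced at load time through a single --hf-overrides JSON flag, or the matching hf_overrides dict in the Python API. A minimal sketch of the Python-side usage; the model name and override values below are illustrative, not taken from this diff:

    from vllm import LLM

    # Equivalent to the CLI flag:
    #   --hf-overrides '{"rope_scaling": {"type": "dynamic", "factor": 2.0},
    #                    "rope_theta": 1000000.0}'
    # Each key replaces the matching field loaded from the model's config.json.
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",  # illustrative model choice
        hf_overrides={
            "rope_scaling": {"type": "dynamic", "factor": 2.0},
            "rope_theta": 1000000.0,
        },
    )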

tests/test_config.py

@@ -200,8 +200,10 @@ def test_rope_customization():
         trust_remote_code=False,
         dtype="float16",
         seed=0,
-        rope_scaling=TEST_ROPE_SCALING,
-        rope_theta=TEST_ROPE_THETA,
+        hf_overrides={
+            "rope_scaling": TEST_ROPE_SCALING,
+            "rope_theta": TEST_ROPE_THETA,
+        },
     )
     assert getattr(llama_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING
@@ -232,7 +234,9 @@ def test_rope_customization():
         trust_remote_code=False,
         dtype="float16",
         seed=0,
-        rope_scaling=TEST_ROPE_SCALING,
+        hf_overrides={
+            "rope_scaling": TEST_ROPE_SCALING,
+        },
     )
     assert getattr(longchat_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING

vllm/config.py

@@ -1,5 +1,6 @@
 import enum
 import json
+import warnings
 from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
                     Mapping, Optional, Set, Tuple, Type, Union)
@@ -74,9 +75,6 @@ class ModelConfig:
         code_revision: The specific revision to use for the model code on
             Hugging Face Hub. It can be a branch name, a tag name, or a
             commit id. If unspecified, will use the default version.
-        rope_scaling: Dictionary containing the scaling configuration for the
-            RoPE embeddings. When using this flag, don't update
-            `max_position_embeddings` to the expected new maximum.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -116,6 +114,7 @@
             can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
+        hf_overrides: Arguments to be forwarded to the HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
         pooling_type: Used to configure the pooling method in the embedding
@@ -146,7 +145,7 @@
         allowed_local_media_path: str = "",
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
+        rope_scaling: Optional[Dict[str, Any]] = None,
         rope_theta: Optional[float] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
@@ -164,6 +163,7 @@
         override_neuron_config: Optional[Dict[str, Any]] = None,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         chat_template_text_format: str = "string",
+        hf_overrides: Optional[Dict[str, Any]] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         pooling_type: Optional[str] = None,
         pooling_norm: Optional[bool] = None,
@@ -178,8 +178,22 @@
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
+
+        if hf_overrides is None:
+            hf_overrides = {}
+        if rope_scaling is not None:
+            hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
+            hf_overrides.update(hf_override)
+            msg = ("`--rope-scaling` will be removed in a future release. "
+                   f"Please instead use `--hf-overrides '{hf_override!r}'`")
+            warnings.warn(DeprecationWarning(msg), stacklevel=2)
+        if rope_theta is not None:
+            hf_override = {"rope_theta": rope_theta}
+            hf_overrides.update(hf_override)
+            msg = ("`--rope-theta` will be removed in a future release. "
+                   f"Please instead use `--hf-overrides '{hf_override!r}'`")
+            warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
         # The tokenizer version is consistent with the model version by default.
         if tokenizer_revision is None:
             self.tokenizer_revision = revision
@@ -193,8 +207,8 @@
         self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling, rope_theta,
-                                    config_format)
+                                    code_revision, config_format,
+                                    **hf_overrides)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
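
The shim above keeps the deprecated rope_scaling/rope_theta arguments working by folding them into hf_overrides before get_config is called. A standalone sketch of that merge logic; the helper name and message text are illustrative, not part of the diff:

    import warnings
    from typing import Any, Dict, Optional

    def fold_legacy_rope_args(
            hf_overrides: Optional[Dict[str, Any]],
            rope_scaling: Optional[Dict[str, Any]] = None,
            rope_theta: Optional[float] = None) -> Dict[str, Any]:
        # Mirrors the ModelConfig.__init__ shim: each legacy value is copied
        # into the overrides dict and a DeprecationWarning is emitted.
        merged = dict(hf_overrides or {})
        for key, value in (("rope_scaling", rope_scaling),
                           ("rope_theta", rope_theta)):
            if value is not None:
                merged[key] = value
                warnings.warn(
                    DeprecationWarning(
                        f"`--{key.replace('_', '-')}` will be removed in a "
                        "future release. Please instead use `--hf-overrides`"),
                    stacklevel=2)
        return merged

    # Because the legacy value is written last, it wins over a conflicting
    # entry already present in hf_overrides:
    assert fold_legacy_rope_args(
        {"rope_theta": 10000.0},
        rope_theta=1000000.0)["rope_theta"] == 1000000.0

Note that in the real __init__ the dict passed as hf_overrides is mutated in place via update(); the copy in this sketch only keeps the example side-effect free.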

vllm/engine/arg_utils.py

@@ -128,8 +128,9 @@ class EngineArgs:
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
-    rope_scaling: Optional[dict] = None
+    rope_scaling: Optional[Dict[str, Any]] = None
     rope_theta: Optional[float] = None
+    hf_overrides: Optional[Dict[str, Any]] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: Optional[bool] = None
@@ -140,8 +141,9 @@
     # is intended for expert use only. The API may change without
     # notice.
     tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[dict] = None
+    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
     enable_lora: bool = False
     max_loras: int = 1
     max_lora_rank: int = 16
@@ -187,7 +189,6 @@
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
     override_neuron_config: Optional[Dict[str, Any]] = None
-    mm_processor_kwargs: Optional[Dict[str, Any]] = None
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

     # Pooling configuration.
@@ -512,6 +513,12 @@
                             help='RoPE theta. Use with `rope_scaling`. In '
                             'some cases, changing the RoPE theta improves the '
                             'performance of the scaled model.')
+        parser.add_argument('--hf-overrides',
+                            type=json.loads,
+                            default=EngineArgs.hf_overrides,
+                            help='Extra arguments for the HuggingFace config. '
+                            'This should be a JSON string that will be '
+                            'parsed into a dictionary.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -940,6 +947,7 @@
             code_revision=self.code_revision,
             rope_scaling=self.rope_scaling,
             rope_theta=self.rope_theta,
+            hf_overrides=self.hf_overrides,
             tokenizer_revision=self.tokenizer_revision,
             max_model_len=self.max_model_len,
             quantization=self.quantization,
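
Since the new flag is declared with type=json.loads, argparse hands the raw shell string directly to the JSON parser and stores the resulting dict. A self-contained illustration of just that wiring (the real parser lives in EngineArgs.add_cli_args):

    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument("--hf-overrides", type=json.loads, default=None)

    # On a real command line this would be written as:
    #   --hf-overrides '{"rope_scaling": {"type": "dynamic", "factor": 2.0}}'
    args = parser.parse_args(
        ["--hf-overrides",
         '{"rope_scaling": {"type": "dynamic", "factor": 2.0}}'])
    assert args.hf_overrides["rope_scaling"]["factor"] == 2.0

Malformed JSON raises json.JSONDecodeError (a ValueError subclass), so argparse rejects it at parse time rather than failing later inside config loading.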

vllm/engine/llm_engine.py

@@ -248,8 +248,7 @@ class LLMEngine:
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "override_neuron_config=%s, "
-            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
+            "override_neuron_config=%s, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "pipeline_parallel_size=%d, "
@@ -271,8 +270,6 @@
             model_config.tokenizer_mode,
             model_config.revision,
             model_config.override_neuron_config,
-            model_config.rope_scaling,
-            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,

vllm/entrypoints/llm.py

@@ -98,7 +98,10 @@ class LLM:
             to eager mode. Additionally for encoder-decoder models, if the
             sequence length of the encoder input is larger than this, we fall
             back to the eager mode.
-        disable_custom_all_reduce: See ParallelConfig
+        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
+        disable_async_output_proc: Disable async output processing.
+            This may result in lower performance.
+        hf_overrides: Arguments to be forwarded to the HuggingFace config.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)

@@ -153,6 +156,7 @@
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
+        hf_overrides: Optional[dict] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
@@ -194,6 +198,7 @@
             max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
             disable_async_output_proc=disable_async_output_proc,
+            hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             pooling_type=pooling_type,
             pooling_norm=pooling_norm,

vllm/transformers_utils/config.py

@@ -146,9 +146,8 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     code_revision: Optional[str] = None,
-    rope_scaling: Optional[dict] = None,
-    rope_theta: Optional[float] = None,
     config_format: ConfigFormat = ConfigFormat.AUTO,
+    token: Optional[str] = None,
     **kwargs,
 ) -> PretrainedConfig:
     # Separate model folder from file path for GGUF models
@@ -159,39 +158,43 @@
         model = Path(model).parent

     if config_format == ConfigFormat.AUTO:
-        if is_gguf or file_or_path_exists(model,
-                                          HF_CONFIG_NAME,
-                                          revision=revision,
-                                          token=kwargs.get("token")):
+        if is_gguf or file_or_path_exists(
+                model, HF_CONFIG_NAME, revision=revision, token=token):
             config_format = ConfigFormat.HF
         elif file_or_path_exists(model,
                                  MISTRAL_CONFIG_NAME,
                                  revision=revision,
-                                 token=kwargs.get("token")):
+                                 token=token):
             config_format = ConfigFormat.MISTRAL
         else:
             # If we're in offline mode and found no valid config format, then
             # raise an offline mode error to indicate to the user that they
             # don't have files cached and may need to go online.
             # This is conveniently triggered by calling file_exists().
-            file_exists(model,
-                        HF_CONFIG_NAME,
-                        revision=revision,
-                        token=kwargs.get("token"))
+            file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)

             raise ValueError(f"No supported config format found in {model}")

     if config_format == ConfigFormat.HF:
         config_dict, _ = PretrainedConfig.get_config_dict(
-            model, revision=revision, code_revision=code_revision, **kwargs)
+            model,
+            revision=revision,
+            code_revision=code_revision,
+            token=token,
+            **kwargs,
+        )

         # Use custom model class if it's in our registry
         model_type = config_dict.get("model_type")
         if model_type in _CONFIG_REGISTRY:
             config_class = _CONFIG_REGISTRY[model_type]
-            config = config_class.from_pretrained(model,
-                                                  revision=revision,
-                                                  code_revision=code_revision)
+            config = config_class.from_pretrained(
+                model,
+                revision=revision,
+                code_revision=code_revision,
+                token=token,
+                **kwargs,
+            )
         else:
             try:
                 config = AutoConfig.from_pretrained(
@@ -199,6 +202,7 @@
                     trust_remote_code=trust_remote_code,
                     revision=revision,
                     code_revision=code_revision,
+                    token=token,
                     **kwargs,
                 )
             except ValueError as e:
@@ -216,7 +220,7 @@
                     raise e

     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision, token=kwargs.get("token"))
+        config = load_params_config(model, revision, token=token, **kwargs)
     else:
         raise ValueError(f"Unsupported config format: {config_format}")

@@ -228,19 +232,6 @@
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})

-    for key, value in [
-            ("rope_scaling", rope_scaling),
-            ("rope_theta", rope_theta),
-    ]:
-        if value is not None:
-            logger.info(
-                "Updating %s from %r to %r",
-                key,
-                getattr(config, key, None),
-                value,
-            )
-            config.update({key: value})
-
     patch_rope_scaling(config)

     return config
@@ -462,13 +453,15 @@ def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:

 def load_params_config(model: Union[str, Path],
                        revision: Optional[str],
-                       token: Optional[str] = None) -> PretrainedConfig:
+                       token: Optional[str] = None,
+                       **kwargs) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format

     config_file_name = "params.json"

     config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
+    assert isinstance(config_dict, dict)

     config_mapping = {
         "dim": "hidden_size",
@@ -512,6 +505,8 @@
         config_dict["architectures"] = ["PixtralForConditionalGeneration"]
         config_dict["model_type"] = "pixtral"

+    config_dict.update(kwargs)
+
     config = recurse_elems(config_dict)

     return config
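
On the Mistral params.json path, overrides land via a plain config_dict.update(kwargs) just before the dict is converted into a config object, so an override key always replaces the value read from disk. In miniature (numbers are illustrative):

    # Values read from params.json...
    config_dict = {"hidden_size": 4096, "rope_theta": 10000.0}
    # ...are replaced by any overrides forwarded through **kwargs:
    config_dict.update({"rope_theta": 1000000.0})
    assert config_dict == {"hidden_size": 4096, "rope_theta": 1000000.0}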

vllm/v1/engine/llm_engine.py

@@ -74,8 +74,7 @@ class LLMEngine:
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "override_neuron_config=%s, "
-            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
+            "override_neuron_config=%s, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "pipeline_parallel_size=%d, "
@@ -94,8 +93,6 @@
             model_config.tokenizer_mode,
             model_config.revision,
             model_config.override_neuron_config,
-            model_config.rope_scaling,
-            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,