[Frontend][Core] Override HF config.json via CLI (#5836)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent d88bff1b96
commit b09895a618
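This change adds an `hf_overrides` argument to `ModelConfig`/`EngineArgs`/`LLM` (exposed on the CLI as `--hf-overrides` and parsed with `json.loads`) so that arbitrary keys can be forwarded to the HuggingFace config, and it deprecates the dedicated `rope_scaling`/`rope_theta` arguments in favor of it. A minimal usage sketch of the new Python-level argument; the model name and override values below are illustrative placeholders, not taken from the diff:

from vllm import LLM

# Override fields of the HF config.json at load time instead of passing the
# now-deprecated rope_scaling/rope_theta arguments (values are placeholders).
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    hf_overrides={
        "rope_theta": 1000000.0,
        "rope_scaling": {"rope_type": "dynamic", "factor": 2.0},
    },
)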
@@ -200,8 +200,10 @@ def test_rope_customization():
         trust_remote_code=False,
         dtype="float16",
         seed=0,
-        rope_scaling=TEST_ROPE_SCALING,
-        rope_theta=TEST_ROPE_THETA,
+        hf_overrides={
+            "rope_scaling": TEST_ROPE_SCALING,
+            "rope_theta": TEST_ROPE_THETA,
+        },
     )
     assert getattr(llama_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING
@@ -232,7 +234,9 @@ def test_rope_customization():
         trust_remote_code=False,
         dtype="float16",
         seed=0,
-        rope_scaling=TEST_ROPE_SCALING,
+        hf_overrides={
+            "rope_scaling": TEST_ROPE_SCALING,
+        },
     )
     assert getattr(longchat_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING
@@ -1,5 +1,6 @@
 import enum
 import json
+import warnings
 from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
                     Mapping, Optional, Set, Tuple, Type, Union)
@@ -74,9 +75,6 @@ class ModelConfig:
         code_revision: The specific revision to use for the model code on
             Hugging Face Hub. It can be a branch name, a tag name, or a
             commit id. If unspecified, will use the default version.
-        rope_scaling: Dictionary containing the scaling configuration for the
-            RoPE embeddings. When using this flag, don't update
-            `max_position_embeddings` to the expected new maximum.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -116,6 +114,7 @@ class ModelConfig:
             can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
+        hf_overrides: Arguments to be forwarded to the HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
         pooling_type: Used to configure the pooling method in the embedding
@@ -146,7 +145,7 @@ class ModelConfig:
         allowed_local_media_path: str = "",
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
+        rope_scaling: Optional[Dict[str, Any]] = None,
         rope_theta: Optional[float] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
@@ -164,6 +163,7 @@ class ModelConfig:
         override_neuron_config: Optional[Dict[str, Any]] = None,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         chat_template_text_format: str = "string",
+        hf_overrides: Optional[Dict[str, Any]] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         pooling_type: Optional[str] = None,
         pooling_norm: Optional[bool] = None,
@@ -178,8 +178,22 @@ class ModelConfig:
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
+
+        if hf_overrides is None:
+            hf_overrides = {}
+        if rope_scaling is not None:
+            hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
+            hf_overrides.update(hf_override)
+            msg = ("`--rope-scaling` will be removed in a future release. "
+                   f"'Please instead use `--hf-overrides '{hf_override!r}'`")
+            warnings.warn(DeprecationWarning(msg), stacklevel=2)
+        if rope_theta is not None:
+            hf_override = {"rope_theta": rope_theta}
+            hf_overrides.update(hf_override)
+            msg = ("`--rope-theta` will be removed in a future release. "
+                   f"'Please instead use `--hf-overrides '{hf_override!r}'`")
+            warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
         # The tokenizer version is consistent with the model version by default.
         if tokenizer_revision is None:
             self.tokenizer_revision = revision
@@ -193,8 +207,8 @@ class ModelConfig:
         self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling, rope_theta,
-                                    config_format)
+                                    code_revision, config_format,
+                                    **hf_overrides)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
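For readers skimming the `ModelConfig` hunk above: when the legacy `rope_scaling` / `rope_theta` arguments are still supplied, they are now folded into `hf_overrides` and a `DeprecationWarning` is emitted. A standalone sketch of that mapping, with a hypothetical helper name that is not part of vLLM:

import warnings
from typing import Any, Dict, Optional


def merge_legacy_rope_args(hf_overrides: Optional[Dict[str, Any]] = None,
                           rope_scaling: Optional[Dict[str, Any]] = None,
                           rope_theta: Optional[float] = None) -> Dict[str, Any]:
    # Mirrors the constructor shim above: fold deprecated flags into hf_overrides.
    overrides: Dict[str, Any] = dict(hf_overrides or {})
    for flag, key, value in (("--rope-scaling", "rope_scaling", rope_scaling),
                             ("--rope-theta", "rope_theta", rope_theta)):
        if value is not None:
            overrides[key] = value
            warnings.warn(
                DeprecationWarning(f"`{flag}` will be removed in a future "
                                   f"release. Please use `--hf-overrides` "
                                   f"with {{{key!r}: {value!r}}} instead."),
                stacklevel=2)
    return overrides


# merge_legacy_rope_args(rope_theta=1e6) -> {'rope_theta': 1000000.0} (plus a warning)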
@@ -128,8 +128,9 @@ class EngineArgs:
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
-    rope_scaling: Optional[dict] = None
+    rope_scaling: Optional[Dict[str, Any]] = None
     rope_theta: Optional[float] = None
+    hf_overrides: Optional[Dict[str, Any]] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: Optional[bool] = None
@@ -140,8 +141,9 @@ class EngineArgs:
     # is intended for expert use only. The API may change without
     # notice.
     tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[dict] = None
+    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
     enable_lora: bool = False
     max_loras: int = 1
     max_lora_rank: int = 16
@@ -187,7 +189,6 @@ class EngineArgs:
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
     override_neuron_config: Optional[Dict[str, Any]] = None
-    mm_processor_kwargs: Optional[Dict[str, Any]] = None
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
 
     # Pooling configuration.
@@ -512,6 +513,12 @@ class EngineArgs:
                             help='RoPE theta. Use with `rope_scaling`. In '
                             'some cases, changing the RoPE theta improves the '
                             'performance of the scaled model.')
+        parser.add_argument('--hf-overrides',
+                            type=json.loads,
+                            default=EngineArgs.hf_overrides,
+                            help='Extra arguments for the HuggingFace config.'
+                            'This should be a JSON string that will be '
+                            'parsed into a dictionary.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -940,6 +947,7 @@ class EngineArgs:
             code_revision=self.code_revision,
             rope_scaling=self.rope_scaling,
             rope_theta=self.rope_theta,
+            hf_overrides=self.hf_overrides,
             tokenizer_revision=self.tokenizer_revision,
             max_model_len=self.max_model_len,
             quantization=self.quantization,
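Since `--hf-overrides` is parsed with `json.loads`, the flag expects a single JSON object on the command line. A small sketch of that parsing step; the override keys shown are illustrative:

import json

# Shell equivalent (illustrative):
#   --hf-overrides '{"rope_theta": 1000000.0, "rope_scaling": {"factor": 2.0}}'
raw = '{"rope_theta": 1000000.0, "rope_scaling": {"factor": 2.0}}'
overrides = json.loads(raw)
assert overrides["rope_theta"] == 1000000.0  # later forwarded via get_config(**hf_overrides)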
@@ -248,8 +248,7 @@ class LLMEngine:
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "override_neuron_config=%s, "
-            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
+            "override_neuron_config=%s, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "pipeline_parallel_size=%d, "
@@ -271,8 +270,6 @@ class LLMEngine:
             model_config.tokenizer_mode,
             model_config.revision,
             model_config.override_neuron_config,
-            model_config.rope_scaling,
-            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,
@@ -98,7 +98,10 @@ class LLM:
             to eager mode. Additionally for encoder-decoder models, if the
             sequence length of the encoder input is larger than this, we fall
             back to the eager mode.
-        disable_custom_all_reduce: See ParallelConfig
+        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
+        disable_async_output_proc: Disable async output processing.
+            This may result in lower performance.
+        hf_overrides: Arguments to be forwarded to the HuggingFace config.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
 
@@ -153,6 +156,7 @@ class LLM:
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
+        hf_overrides: Optional[dict] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
@@ -194,6 +198,7 @@ class LLM:
             max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
             disable_async_output_proc=disable_async_output_proc,
+            hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             pooling_type=pooling_type,
             pooling_norm=pooling_norm,
@@ -146,9 +146,8 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     code_revision: Optional[str] = None,
-    rope_scaling: Optional[dict] = None,
-    rope_theta: Optional[float] = None,
     config_format: ConfigFormat = ConfigFormat.AUTO,
+    token: Optional[str] = None,
     **kwargs,
 ) -> PretrainedConfig:
     # Separate model folder from file path for GGUF models
@@ -159,39 +158,43 @@ def get_config(
         model = Path(model).parent
 
     if config_format == ConfigFormat.AUTO:
-        if is_gguf or file_or_path_exists(model,
-                                          HF_CONFIG_NAME,
-                                          revision=revision,
-                                          token=kwargs.get("token")):
+        if is_gguf or file_or_path_exists(
+                model, HF_CONFIG_NAME, revision=revision, token=token):
             config_format = ConfigFormat.HF
         elif file_or_path_exists(model,
                                  MISTRAL_CONFIG_NAME,
                                  revision=revision,
-                                 token=kwargs.get("token")):
+                                 token=token):
             config_format = ConfigFormat.MISTRAL
         else:
             # If we're in offline mode and found no valid config format, then
             # raise an offline mode error to indicate to the user that they
             # don't have files cached and may need to go online.
             # This is conveniently triggered by calling file_exists().
-            file_exists(model,
-                        HF_CONFIG_NAME,
-                        revision=revision,
-                        token=kwargs.get("token"))
+            file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
 
             raise ValueError(f"No supported config format found in {model}")
 
     if config_format == ConfigFormat.HF:
         config_dict, _ = PretrainedConfig.get_config_dict(
-            model, revision=revision, code_revision=code_revision, **kwargs)
+            model,
+            revision=revision,
+            code_revision=code_revision,
+            token=token,
+            **kwargs,
+        )
 
         # Use custom model class if it's in our registry
         model_type = config_dict.get("model_type")
         if model_type in _CONFIG_REGISTRY:
             config_class = _CONFIG_REGISTRY[model_type]
-            config = config_class.from_pretrained(model,
-                                                  revision=revision,
-                                                  code_revision=code_revision)
+            config = config_class.from_pretrained(
+                model,
+                revision=revision,
+                code_revision=code_revision,
+                token=token,
+                **kwargs,
+            )
         else:
             try:
                 config = AutoConfig.from_pretrained(
@@ -199,6 +202,7 @@ def get_config(
                     trust_remote_code=trust_remote_code,
                     revision=revision,
                     code_revision=code_revision,
+                    token=token,
                     **kwargs,
                 )
             except ValueError as e:
@@ -216,7 +220,7 @@ def get_config(
                 raise e
 
     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision, token=kwargs.get("token"))
+        config = load_params_config(model, revision, token=token, **kwargs)
     else:
         raise ValueError(f"Unsupported config format: {config_format}")
 
@@ -228,19 +232,6 @@ def get_config(
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})
 
-    for key, value in [
-        ("rope_scaling", rope_scaling),
-        ("rope_theta", rope_theta),
-    ]:
-        if value is not None:
-            logger.info(
-                "Updating %s from %r to %r",
-                key,
-                getattr(config, key, None),
-                value,
-            )
-            config.update({key: value})
-
     patch_rope_scaling(config)
 
     return config
@@ -462,13 +453,15 @@ def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
 
 def load_params_config(model: Union[str, Path],
                        revision: Optional[str],
-                       token: Optional[str] = None) -> PretrainedConfig:
+                       token: Optional[str] = None,
+                       **kwargs) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format
 
     config_file_name = "params.json"
 
     config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
+    assert isinstance(config_dict, dict)
 
     config_mapping = {
         "dim": "hidden_size",
@@ -512,6 +505,8 @@ def load_params_config(model: Union[str, Path],
         config_dict["architectures"] = ["PixtralForConditionalGeneration"]
         config_dict["model_type"] = "pixtral"
 
+    config_dict.update(kwargs)
+
     config = recurse_elems(config_dict)
     return config
 
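On the mechanism in `get_config` above: the override dictionary is unpacked as keyword arguments into `from_pretrained`, and in `transformers` any keyword argument that matches an existing config attribute replaces the value read from `config.json`. A sketch of that behavior using `transformers` directly; the model name is a placeholder and fetching it requires network access:

from transformers import AutoConfig

# Keyword arguments that match existing config fields override config.json values.
config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B", rope_theta=1000000.0)
print(config.rope_theta)  # 1000000.0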
@@ -74,8 +74,7 @@ class LLMEngine:
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "override_neuron_config=%s, "
-            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
+            "override_neuron_config=%s, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "pipeline_parallel_size=%d, "
@@ -94,8 +93,6 @@ class LLMEngine:
             model_config.tokenizer_mode,
             model_config.revision,
             model_config.override_neuron_config,
-            model_config.rope_scaling,
-            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,