support ignore patterns in model loader (#6673)
parent 22fa2e35cb
commit 3eda4ec780
@@ -599,12 +599,16 @@ class LoadConfig:
                 mainly for profiling.
             "tensorizer" will use CoreWeave's tensorizer library for
                 fast weight loading.
+        ignore_patterns: The list of patterns to ignore when loading the model.
+            Default to "original/**/*" to avoid repeated loading of llama's
+            checkpoints.
     """
 
     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
     download_dir: Optional[str] = None
     model_loader_extra_config: Optional[Union[str, dict]] = field(
         default_factory=dict)
+    ignore_patterns: Optional[Union[List[str], str]] = None
 
     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
@@ -613,6 +617,13 @@ class LoadConfig:
                 model_loader_extra_config)
         self._verify_load_format()
 
+        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+            logger.info(
+                "Ignoring the following patterns when downloading weights: %s",
+                self.ignore_patterns)
+        else:
+            self.ignore_patterns = ["original/**/*"]
+
     def _verify_load_format(self) -> None:
         if not isinstance(self.load_format, str):
             return
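
The net effect of the two LoadConfig hunks above: an explicitly supplied, non-empty ignore_patterns is kept and logged, while None or an empty list falls back to ["original/**/*"]. A minimal standalone sketch of that defaulting logic (LoadConfigSketch is a hypothetical stand-in for illustration, not the real vllm.config.LoadConfig):

    from dataclasses import dataclass
    from typing import List, Optional, Union

    @dataclass
    class LoadConfigSketch:
        ignore_patterns: Optional[Union[List[str], str]] = None

        def __post_init__(self):
            if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
                # vLLM uses logger.info here; print keeps the sketch dependency-free.
                print("Ignoring patterns:", self.ignore_patterns)
            else:
                # Fall back to skipping llama's duplicated "original/" checkpoints.
                self.ignore_patterns = ["original/**/*"]

    assert LoadConfigSketch().ignore_patterns == ["original/**/*"]
    assert LoadConfigSketch(ignore_patterns=["*.pt"]).ignore_patterns == ["*.pt"]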
@@ -801,7 +812,9 @@ class SchedulerConfig:
             # for higher throughput.
             self.max_num_batched_tokens = max(max_model_len, 2048)
         if enable_chunked_prefill:
-            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
+            logger.info(
+                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
+                max_num_batched_tokens)
 
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
@@ -95,6 +95,7 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
     model_loader_extra_config: Optional[dict] = None
+    ignore_patterns: Optional[Union[str, List[str]]] = None
     preemption_mode: Optional[str] = None
 
     scheduler_delay_factor: float = 0.0
@@ -619,6 +620,14 @@ class EngineArgs:
                             'corresponding to the chosen load_format. '
                             'This should be a JSON string that will be '
                             'parsed into a dictionary.')
+        parser.add_argument(
+            '--ignore-patterns',
+            action="append",
+            type=str,
+            default=[],
+            help="The pattern(s) to ignore when loading the model. "
+            "Default to 'original/**/*' to avoid repeated loading of llama's "
+            "checkpoints.")
         parser.add_argument(
             '--preemption-mode',
             type=str,
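
The new flag uses action="append", so it can be given multiple times and each occurrence is collected into a list; when it is omitted, the default [] is what LoadConfig.__post_init__ later replaces with ["original/**/*"]. A standalone sketch of that argparse behavior (not vLLM's actual parser, just the same pattern in isolation):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--ignore-patterns', action="append", type=str, default=[])

    # Omitted: the empty default survives, and LoadConfig fills in the fallback later.
    assert parser.parse_args([]).ignore_patterns == []
    # Repeated: each occurrence is appended in order.
    args = parser.parse_args(
        ['--ignore-patterns', 'original/**/*', '--ignore-patterns', '*.pth'])
    assert args.ignore_patterns == ['original/**/*', '*.pth']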
@@ -824,6 +833,7 @@ class EngineArgs:
             load_format=self.load_format,
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
+            ignore_patterns=self.ignore_patterns,
         )
 
         prompt_adapter_config = PromptAdapterConfig(
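
Since EngineArgs now carries the field and forwards it into LoadConfig, the option should also be reachable from the offline API. A hedged sketch, assuming vllm.LLM passes extra keyword arguments through to EngineArgs as it does for other engine options (the model name is only an example):

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",  # example model only
        ignore_patterns=["original/**/*"],  # skip the duplicate original/ checkpoints
    )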
@@ -161,6 +161,7 @@ class DefaultModelLoader(BaseModelLoader):
                 cache_dir=self.load_config.download_dir,
                 local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                 revision=revision,
+                ignore_patterns=self.load_config.ignore_patterns,
             )
         else:
             model_path = model
@@ -196,9 +197,13 @@ class DefaultModelLoader(BaseModelLoader):
             allow_patterns += ["*.pt"]
 
         if not is_local:
-            hf_folder = download_weights_from_hf(model_name_or_path,
-                                                 self.load_config.download_dir,
-                                                 allow_patterns, revision)
+            hf_folder = download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
         else:
             hf_folder = model_name_or_path
 
@@ -489,9 +494,13 @@ class ShardedStateLoader(BaseModelLoader):
             return model_name_or_path
         else:
             allow_patterns = ["*.safetensors"]
-            return download_weights_from_hf(model_name_or_path,
-                                            self.load_config.download_dir,
-                                            allow_patterns, revision)
+            return download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
 
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
@@ -663,8 +672,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             matching_files = fnmatch.filter(repo_files, pattern)
             if matching_files:
                 hf_folder = download_weights_from_hf(
-                    model_name_or_path, self.load_config.download_dir,
-                    [pattern], revision)
+                    model_name_or_path,
+                    self.load_config.download_dir,
+                    [pattern],
+                    revision,
+                    ignore_patterns=self.load_config.ignore_patterns,
+                )
                 return glob.glob(os.path.join(hf_folder, pattern)), pattern
 
         raise RuntimeError(
@@ -6,7 +6,7 @@ import json
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Generator, Iterable, List, Optional, Tuple
+from typing import Any, Generator, Iterable, List, Optional, Tuple, Union
 
 import filelock
 import huggingface_hub.constants
@@ -189,6 +189,7 @@ def download_weights_from_hf(
     cache_dir: Optional[str],
     allow_patterns: List[str],
     revision: Optional[str] = None,
+    ignore_patterns: Optional[Union[str, List[str]]] = None,
 ) -> str:
     """Download model weights from Hugging Face Hub.
 
@@ -200,6 +201,9 @@ def download_weights_from_hf(
             weight files. Files matched by any of the patterns will be
             downloaded.
         revision (Optional[str]): The revision of the model.
+        ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
+            filter out the weight files. Files matched by any of the patterns
+            will be ignored.
 
     Returns:
         str: The path to the downloaded model weights.
@@ -223,6 +227,7 @@ def download_weights_from_hf(
         hf_folder = snapshot_download(
             model_name_or_path,
             allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
             cache_dir=cache_dir,
             tqdm_class=DisabledTqdm,
             revision=revision,
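
The value ends up in huggingface_hub.snapshot_download, which already supports allow_patterns/ignore_patterns filtering. For reference, a minimal direct use of that library call outside vLLM (repo id and patterns are illustrative only):

    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(
        "meta-llama/Meta-Llama-3-8B",                # illustrative repo id
        allow_patterns=["*.safetensors", "*.json"],  # fetch only these files
        ignore_patterns=["original/**/*"],           # and skip anything matching these
    )
    print(local_dir)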