support ignore patterns in model loader (#6673)

This commit is contained in:
Simon Mo 2024-07-22 23:59:42 -07:00 committed by GitHub
parent 22fa2e35cb
commit 3eda4ec780
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 10 deletions

View File

@ -599,12 +599,16 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
ignore_patterns: The list of patterns to ignore when loading the model.
Defaults to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(
default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
@ -613,6 +617,13 @@ class LoadConfig:
model_loader_extra_config)
self._verify_load_format()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info(
"Ignoring the following patterns when downloading weights: %s",
self.ignore_patterns)
else:
self.ignore_patterns = ["original/**/*"]
def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return
@ -801,7 +812,9 @@ class SchedulerConfig:
# for higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
max_num_batched_tokens)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len

View File

@ -95,6 +95,7 @@ class EngineArgs:
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
ignore_patterns: Optional[Union[str, List[str]]] = None
preemption_mode: Optional[str] = None
scheduler_delay_factor: float = 0.0
@ -619,6 +620,14 @@ class EngineArgs:
'corresponding to the chosen load_format. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument(
'--ignore-patterns',
action="append",
type=str,
default=[],
help="The pattern(s) to ignore when loading the model. "
"Defaults to 'original/**/*' to avoid repeated loading of llama's "
"checkpoints.")
parser.add_argument(
'--preemption-mode',
type=str,
@ -824,6 +833,7 @@ class EngineArgs:
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
)
prompt_adapter_config = PromptAdapterConfig(

View File

@ -161,6 +161,7 @@ class DefaultModelLoader(BaseModelLoader):
cache_dir=self.load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
ignore_patterns=self.load_config.ignore_patterns,
)
else:
model_path = model
@ -196,9 +197,13 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns += ["*.pt"]
if not is_local:
hf_folder = download_weights_from_hf(model_name_or_path,
self.load_config.download_dir,
allow_patterns, revision)
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
else:
hf_folder = model_name_or_path
@ -489,9 +494,13 @@ class ShardedStateLoader(BaseModelLoader):
return model_name_or_path
else:
allow_patterns = ["*.safetensors"]
return download_weights_from_hf(model_name_or_path,
self.load_config.download_dir,
allow_patterns, revision)
return download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
@ -663,8 +672,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
model_name_or_path, self.load_config.download_dir,
[pattern], revision)
model_name_or_path,
self.load_config.download_dir,
[pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
return glob.glob(os.path.join(hf_folder, pattern)), pattern
raise RuntimeError(

View File

@ -6,7 +6,7 @@ import json
import os
import tempfile
from collections import defaultdict
from typing import Any, Generator, Iterable, List, Optional, Tuple
from typing import Any, Generator, Iterable, List, Optional, Tuple, Union
import filelock
import huggingface_hub.constants
@ -189,6 +189,7 @@ def download_weights_from_hf(
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None,
ignore_patterns: Optional[Union[str, List[str]]] = None,
) -> str:
"""Download model weights from Hugging Face Hub.
@ -200,6 +201,9 @@ def download_weights_from_hf(
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
ignore_patterns (Optional[Union[str, List[str]]]): The patterns used to
filter out weight files. Files matched by any of the patterns
will be skipped during download.
Returns:
str: The path to the downloaded model weights.
@ -223,6 +227,7 @@ def download_weights_from_hf(
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,