support ignore patterns in model loader (#6673)

Simon Mo 2024-07-22 23:59:42 -07:00 committed by GitHub
parent 22fa2e35cb
commit 3eda4ec780
4 changed files with 51 additions and 10 deletions

View File

@@ -599,12 +599,16 @@ class LoadConfig:
             mainly for profiling.
         "tensorizer" will use CoreWeave's tensorizer library for
             fast weight loading.
+    ignore_patterns: The list of patterns to ignore when loading the model.
+        Default to "original/**/*" to avoid repeated loading of llama's
+        checkpoints.
     """

     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
     download_dir: Optional[str] = None
     model_loader_extra_config: Optional[Union[str, dict]] = field(
         default_factory=dict)
+    ignore_patterns: Optional[Union[List[str], str]] = None

     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
@@ -613,6 +617,13 @@ class LoadConfig:
                 model_loader_extra_config)
         self._verify_load_format()
+
+        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+            logger.info(
+                "Ignoring the following patterns when downloading weights: %s",
+                self.ignore_patterns)
+        else:
+            self.ignore_patterns = ["original/**/*"]

     def _verify_load_format(self) -> None:
         if not isinstance(self.load_format, str):
             return
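For reference, a minimal runnable sketch of the defaulting behavior this hunk adds (a standalone reconstruction, not vLLM's actual LoadConfig; the logger call is replaced by a print):

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class LoadConfigSketch:
    # Hypothetical stand-in for vLLM's LoadConfig, reduced to the
    # ignore_patterns handling introduced in this commit.
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            print("Ignoring patterns:", self.ignore_patterns)
        else:
            # Fall back to skipping the original/ checkpoint directory.
            self.ignore_patterns = ["original/**/*"]

assert LoadConfigSketch().ignore_patterns == ["original/**/*"]
assert LoadConfigSketch(ignore_patterns=["*.bin"]).ignore_patterns == ["*.bin"]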
@@ -801,7 +812,9 @@ class SchedulerConfig:
             # for higher throughput.
             self.max_num_batched_tokens = max(max_model_len, 2048)

         if enable_chunked_prefill:
-            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
+            logger.info(
+                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
+                self.max_num_batched_tokens)

         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len

View File

@@ -95,6 +95,7 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
     model_loader_extra_config: Optional[dict] = None
+    ignore_patterns: Optional[Union[str, List[str]]] = None
     preemption_mode: Optional[str] = None

     scheduler_delay_factor: float = 0.0
@@ -619,6 +620,14 @@ class EngineArgs:
                             'corresponding to the chosen load_format. '
                             'This should be a JSON string that will be '
                             'parsed into a dictionary.')
+        parser.add_argument(
+            '--ignore-patterns',
+            action="append",
+            type=str,
+            default=[],
+            help="The pattern(s) to ignore when loading the model. "
+            "Default to 'original/**/*' to avoid repeated loading of llama's "
+            "checkpoints.")
         parser.add_argument(
             '--preemption-mode',
             type=str,
@@ -824,6 +833,7 @@ class EngineArgs:
             load_format=self.load_format,
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
+            ignore_patterns=self.ignore_patterns,
         )

         prompt_adapter_config = PromptAdapterConfig(
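Since the new flag uses action="append", repeating it accumulates values into a list. A small argparse sketch of just that flag (standalone, not the full EngineArgs parser):

import argparse

parser = argparse.ArgumentParser()
# Mirrors the new --ignore-patterns flag: each occurrence appends one pattern.
parser.add_argument("--ignore-patterns", action="append", type=str, default=[])

args = parser.parse_args(
    ["--ignore-patterns", "original/**/*", "--ignore-patterns", "*.pth"])
assert args.ignore_patterns == ["original/**/*", "*.pth"]

Note that when the flag is omitted, argparse yields the empty-list default, which falls into the else branch of LoadConfig.__post_init__ above, so the "original/**/*" fallback still applies.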

View File

@@ -161,6 +161,7 @@ class DefaultModelLoader(BaseModelLoader):
                 cache_dir=self.load_config.download_dir,
                 local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                 revision=revision,
+                ignore_patterns=self.load_config.ignore_patterns,
             )
         else:
             model_path = model
@@ -196,9 +197,13 @@ class DefaultModelLoader(BaseModelLoader):
             allow_patterns += ["*.pt"]

         if not is_local:
-            hf_folder = download_weights_from_hf(model_name_or_path,
-                                                 self.load_config.download_dir,
-                                                 allow_patterns, revision)
+            hf_folder = download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
         else:
             hf_folder = model_name_or_path
@@ -489,9 +494,13 @@ class ShardedStateLoader(BaseModelLoader):
             return model_name_or_path
         else:
             allow_patterns = ["*.safetensors"]
-            return download_weights_from_hf(model_name_or_path,
-                                            self.load_config.download_dir,
-                                            allow_patterns, revision)
+            return download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )

     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
@@ -663,8 +672,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             matching_files = fnmatch.filter(repo_files, pattern)
             if matching_files:
                 hf_folder = download_weights_from_hf(
-                    model_name_or_path, self.load_config.download_dir,
-                    [pattern], revision)
+                    model_name_or_path,
+                    self.load_config.download_dir,
+                    [pattern],
+                    revision,
+                    ignore_patterns=self.load_config.ignore_patterns,
+                )
                 return glob.glob(os.path.join(hf_folder, pattern)), pattern

         raise RuntimeError(
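All three loaders now thread load_config.ignore_patterns through to download_weights_from_hf. A simplified reconstruction of the allow/ignore filtering that huggingface_hub applies to a repo's file list (file names here are assumed for illustration; the real logic lives in huggingface_hub's filter_repo_objects):

from fnmatch import fnmatch

# Hypothetical repo listing, modeled on a Llama-style layout with a
# duplicate checkpoint under original/.
repo_files = [
    "config.json",
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
    "original/consolidated.00.pth",
]

allow_patterns = ["*.safetensors"]
ignore_patterns = ["original/**/*"]

# A file is downloaded only if it matches some allow pattern
# and no ignore pattern.
selected = [
    f for f in repo_files
    if any(fnmatch(f, p) for p in allow_patterns)
    and not any(fnmatch(f, p) for p in ignore_patterns)
]
print(selected)  # only the two .safetensors shards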

View File

@@ -6,7 +6,7 @@ import json
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Generator, Iterable, List, Optional, Tuple
+from typing import Any, Generator, Iterable, List, Optional, Tuple, Union

 import filelock
 import huggingface_hub.constants
@@ -189,6 +189,7 @@ def download_weights_from_hf(
     cache_dir: Optional[str],
     allow_patterns: List[str],
     revision: Optional[str] = None,
+    ignore_patterns: Optional[Union[str, List[str]]] = None,
 ) -> str:
     """Download model weights from Hugging Face Hub.

@@ -200,6 +201,9 @@ def download_weights_from_hf(
             weight files. Files matched by any of the patterns will be
             downloaded.
         revision (Optional[str]): The revision of the model.
+        ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
+            filter out the weight files. Files matched by any of the patterns
+            will be ignored.

     Returns:
         str: The path to the downloaded model weights.
@@ -223,6 +227,7 @@ def download_weights_from_hf(
         hf_folder = snapshot_download(
             model_name_or_path,
             allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
             cache_dir=cache_dir,
             tqdm_class=DisabledTqdm,
             revision=revision,
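The new ignore_patterns argument is forwarded directly to huggingface_hub's snapshot_download, which already supports it. A standalone usage sketch (the repo id is an assumed example and may require authentication):

from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # assumed example repo id
    allow_patterns=["*.safetensors", "*.json"],
    ignore_patterns=["original/**/*"],  # skip the duplicate original/ checkpoints
)
print(local_dir)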