support ignore patterns in model loader (#6673)
This commit is contained in:
parent
22fa2e35cb
commit
3eda4ec780
@ -599,12 +599,16 @@ class LoadConfig:
|
|||||||
mainly for profiling.
|
mainly for profiling.
|
||||||
"tensorizer" will use CoreWeave's tensorizer library for
|
"tensorizer" will use CoreWeave's tensorizer library for
|
||||||
fast weight loading.
|
fast weight loading.
|
||||||
|
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||||
|
Defaults to "original/**/*" to avoid repeated loading of llama's
|
||||||
|
checkpoints.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
|
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
|
||||||
download_dir: Optional[str] = None
|
download_dir: Optional[str] = None
|
||||||
model_loader_extra_config: Optional[Union[str, dict]] = field(
|
model_loader_extra_config: Optional[Union[str, dict]] = field(
|
||||||
default_factory=dict)
|
default_factory=dict)
|
||||||
|
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||||
@ -613,6 +617,13 @@ class LoadConfig:
|
|||||||
model_loader_extra_config)
|
model_loader_extra_config)
|
||||||
self._verify_load_format()
|
self._verify_load_format()
|
||||||
|
|
||||||
|
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
||||||
|
logger.info(
|
||||||
|
"Ignoring the following patterns when downloading weights: %s",
|
||||||
|
self.ignore_patterns)
|
||||||
|
else:
|
||||||
|
self.ignore_patterns = ["original/**/*"]
|
||||||
|
|
||||||
def _verify_load_format(self) -> None:
|
def _verify_load_format(self) -> None:
|
||||||
if not isinstance(self.load_format, str):
|
if not isinstance(self.load_format, str):
|
||||||
return
|
return
|
||||||
@ -801,7 +812,9 @@ class SchedulerConfig:
|
|||||||
# for higher throughput.
|
# for higher throughput.
|
||||||
self.max_num_batched_tokens = max(max_model_len, 2048)
|
self.max_num_batched_tokens = max(max_model_len, 2048)
|
||||||
if enable_chunked_prefill:
|
if enable_chunked_prefill:
|
||||||
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
|
logger.info(
|
||||||
|
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
|
||||||
|
max_num_batched_tokens)
|
||||||
|
|
||||||
self.max_num_seqs = max_num_seqs
|
self.max_num_seqs = max_num_seqs
|
||||||
self.max_model_len = max_model_len
|
self.max_model_len = max_model_len
|
||||||
|
|||||||
@ -95,6 +95,7 @@ class EngineArgs:
|
|||||||
num_gpu_blocks_override: Optional[int] = None
|
num_gpu_blocks_override: Optional[int] = None
|
||||||
num_lookahead_slots: int = 0
|
num_lookahead_slots: int = 0
|
||||||
model_loader_extra_config: Optional[dict] = None
|
model_loader_extra_config: Optional[dict] = None
|
||||||
|
ignore_patterns: Optional[Union[str, List[str]]] = None
|
||||||
preemption_mode: Optional[str] = None
|
preemption_mode: Optional[str] = None
|
||||||
|
|
||||||
scheduler_delay_factor: float = 0.0
|
scheduler_delay_factor: float = 0.0
|
||||||
@ -619,6 +620,14 @@ class EngineArgs:
|
|||||||
'corresponding to the chosen load_format. '
|
'corresponding to the chosen load_format. '
|
||||||
'This should be a JSON string that will be '
|
'This should be a JSON string that will be '
|
||||||
'parsed into a dictionary.')
|
'parsed into a dictionary.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--ignore-patterns',
|
||||||
|
action="append",
|
||||||
|
type=str,
|
||||||
|
default=[],
|
||||||
|
help="The pattern(s) to ignore when loading the model."
|
||||||
|
"Default to 'original/**/*' to avoid repeated loading of llama's "
|
||||||
|
"checkpoints.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--preemption-mode',
|
'--preemption-mode',
|
||||||
type=str,
|
type=str,
|
||||||
@ -824,6 +833,7 @@ class EngineArgs:
|
|||||||
load_format=self.load_format,
|
load_format=self.load_format,
|
||||||
download_dir=self.download_dir,
|
download_dir=self.download_dir,
|
||||||
model_loader_extra_config=self.model_loader_extra_config,
|
model_loader_extra_config=self.model_loader_extra_config,
|
||||||
|
ignore_patterns=self.ignore_patterns,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_adapter_config = PromptAdapterConfig(
|
prompt_adapter_config = PromptAdapterConfig(
|
||||||
|
|||||||
@ -161,6 +161,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
cache_dir=self.load_config.download_dir,
|
cache_dir=self.load_config.download_dir,
|
||||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
ignore_patterns=self.load_config.ignore_patterns,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = model
|
model_path = model
|
||||||
@ -196,9 +197,13 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
allow_patterns += ["*.pt"]
|
allow_patterns += ["*.pt"]
|
||||||
|
|
||||||
if not is_local:
|
if not is_local:
|
||||||
hf_folder = download_weights_from_hf(model_name_or_path,
|
hf_folder = download_weights_from_hf(
|
||||||
self.load_config.download_dir,
|
model_name_or_path,
|
||||||
allow_patterns, revision)
|
self.load_config.download_dir,
|
||||||
|
allow_patterns,
|
||||||
|
revision,
|
||||||
|
ignore_patterns=self.load_config.ignore_patterns,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
|
|
||||||
@ -489,9 +494,13 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
return model_name_or_path
|
return model_name_or_path
|
||||||
else:
|
else:
|
||||||
allow_patterns = ["*.safetensors"]
|
allow_patterns = ["*.safetensors"]
|
||||||
return download_weights_from_hf(model_name_or_path,
|
return download_weights_from_hf(
|
||||||
self.load_config.download_dir,
|
model_name_or_path,
|
||||||
allow_patterns, revision)
|
self.load_config.download_dir,
|
||||||
|
allow_patterns,
|
||||||
|
revision,
|
||||||
|
ignore_patterns=self.load_config.ignore_patterns,
|
||||||
|
)
|
||||||
|
|
||||||
def load_model(self, *, model_config: ModelConfig,
|
def load_model(self, *, model_config: ModelConfig,
|
||||||
device_config: DeviceConfig,
|
device_config: DeviceConfig,
|
||||||
@ -663,8 +672,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
matching_files = fnmatch.filter(repo_files, pattern)
|
matching_files = fnmatch.filter(repo_files, pattern)
|
||||||
if matching_files:
|
if matching_files:
|
||||||
hf_folder = download_weights_from_hf(
|
hf_folder = download_weights_from_hf(
|
||||||
model_name_or_path, self.load_config.download_dir,
|
model_name_or_path,
|
||||||
[pattern], revision)
|
self.load_config.download_dir,
|
||||||
|
[pattern],
|
||||||
|
revision,
|
||||||
|
ignore_patterns=self.load_config.ignore_patterns,
|
||||||
|
)
|
||||||
return glob.glob(os.path.join(hf_folder, pattern)), pattern
|
return glob.glob(os.path.join(hf_folder, pattern)), pattern
|
||||||
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Generator, Iterable, List, Optional, Tuple
|
from typing import Any, Generator, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import filelock
|
import filelock
|
||||||
import huggingface_hub.constants
|
import huggingface_hub.constants
|
||||||
@ -189,6 +189,7 @@ def download_weights_from_hf(
|
|||||||
cache_dir: Optional[str],
|
cache_dir: Optional[str],
|
||||||
allow_patterns: List[str],
|
allow_patterns: List[str],
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
|
ignore_patterns: Optional[Union[str, List[str]]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Download model weights from Hugging Face Hub.
|
"""Download model weights from Hugging Face Hub.
|
||||||
|
|
||||||
@ -200,6 +201,9 @@ def download_weights_from_hf(
|
|||||||
weight files. Files matched by any of the patterns will be
|
weight files. Files matched by any of the patterns will be
|
||||||
downloaded.
|
downloaded.
|
||||||
revision (Optional[str]): The revision of the model.
|
revision (Optional[str]): The revision of the model.
|
||||||
|
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
|
||||||
|
filter out the weight files. Files matched by any of the patterns
|
||||||
|
will be ignored.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The path to the downloaded model weights.
|
str: The path to the downloaded model weights.
|
||||||
@ -223,6 +227,7 @@ def download_weights_from_hf(
|
|||||||
hf_folder = snapshot_download(
|
hf_folder = snapshot_download(
|
||||||
model_name_or_path,
|
model_name_or_path,
|
||||||
allow_patterns=allow_patterns,
|
allow_patterns=allow_patterns,
|
||||||
|
ignore_patterns=ignore_patterns,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
tqdm_class=DisabledTqdm,
|
tqdm_class=DisabledTqdm,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user