Add code-revision config argument for Hugging Face Hub (#2892)

This commit is contained in:
Mark Mozolewski 2024-02-17 22:36:53 -08:00 committed by GitHub
parent 8f36444c4f
commit 786b7f18a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 11 deletions

View File

@@ -44,6 +44,9 @@ class ModelConfig:
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id. If unspecified, will use the default
             version.
+        code_revision: The specific revision to use for the model code on
+            Hugging Face Hub. It can be a branch name, a tag name, or a
+            commit id. If unspecified, will use the default version.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -70,6 +73,7 @@ class ModelConfig:
         dtype: Union[str, torch.dtype],
         seed: int,
         revision: Optional[str] = None,
+        code_revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
@@ -84,6 +88,7 @@ class ModelConfig:
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.code_revision = code_revision
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.enforce_eager = enforce_eager
@@ -103,7 +108,8 @@ class ModelConfig:
             self.download_dir = model_path
             self.tokenizer = model_path
-        self.hf_config = get_config(self.model, trust_remote_code, revision)
+        self.hf_config = get_config(self.model, trust_remote_code, revision,
+                                    code_revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                      max_model_len)

View File

@@ -32,6 +32,7 @@ class EngineArgs:
     max_paddings: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
+    code_revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
@@ -75,6 +76,13 @@ class EngineArgs:
             help='the specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
+        parser.add_argument(
+            '--code-revision',
+            type=str,
+            default=None,
+            help='the specific revision to use for the model code on '
+            'Hugging Face Hub. It can be a branch name, a tag name, or a '
+            'commit id. If unspecified, will use the default version.')
         parser.add_argument(
             '--tokenizer-revision',
             type=str,
@@ -279,13 +287,12 @@ class EngineArgs:
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
                DeviceConfig, Optional[LoRAConfig]]:
         device_config = DeviceConfig(self.device)
-        model_config = ModelConfig(self.model, self.tokenizer,
-                                   self.tokenizer_mode, self.trust_remote_code,
-                                   self.download_dir, self.load_format,
-                                   self.dtype, self.seed, self.revision,
-                                   self.tokenizer_revision, self.max_model_len,
-                                   self.quantization, self.enforce_eager,
-                                   self.max_context_len_to_capture)
+        model_config = ModelConfig(
+            self.model, self.tokenizer, self.tokenizer_mode,
+            self.trust_remote_code, self.download_dir, self.load_format,
+            self.dtype, self.seed, self.revision, self.code_revision,
+            self.tokenizer_revision, self.max_model_len, self.quantization,
+            self.enforce_eager, self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,

View File

@@ -16,10 +16,14 @@ _CONFIG_REGISTRY = {
 def get_config(model: str,
                trust_remote_code: bool,
-               revision: Optional[str] = None) -> PretrainedConfig:
+               revision: Optional[str] = None,
+               code_revision: Optional[str] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code, revision=revision)
+            model,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            code_revision=code_revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -33,5 +37,7 @@ def get_config(model: str,
             raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model, revision=revision)
+        config = config_class.from_pretrained(model,
+                                              revision=revision,
+                                              code_revision=code_revision)
     return config