Add code-revision config argument for Hugging Face Hub (#2892)

This commit is contained in:
Mark Mozolewski 2024-02-17 22:36:53 -08:00 committed by GitHub
parent 8f36444c4f
commit 786b7f18a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 11 deletions

View File

@ -44,6 +44,9 @@ class ModelConfig:
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
@ -70,6 +73,7 @@ class ModelConfig:
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
@ -84,6 +88,7 @@ class ModelConfig:
self.load_format = load_format
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.enforce_eager = enforce_eager
@ -103,7 +108,8 @@ class ModelConfig:
self.download_dir = model_path
self.tokenizer = model_path
self.hf_config = get_config(self.model, trust_remote_code, revision)
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config,
max_model_len)

View File

@ -32,6 +32,7 @@ class EngineArgs:
max_paddings: int = 256
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
@ -75,6 +76,13 @@ class EngineArgs:
help='the specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=str,
default=None,
help='the specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=str,
@ -279,13 +287,12 @@ class EngineArgs:
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
DeviceConfig, Optional[LoRAConfig]]:
device_config = DeviceConfig(self.device)
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format,
self.dtype, self.seed, self.revision,
self.tokenizer_revision, self.max_model_len,
self.quantization, self.enforce_eager,
self.max_context_len_to_capture)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
self.trust_remote_code, self.download_dir, self.load_format,
self.dtype, self.seed, self.revision, self.code_revision,
self.tokenizer_revision, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,

View File

@ -16,10 +16,14 @@ _CONFIG_REGISTRY = {
def get_config(model: str,
trust_remote_code: bool,
revision: Optional[str] = None) -> PretrainedConfig:
revision: Optional[str] = None,
code_revision: Optional[str] = None) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision)
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
@ -33,5 +37,7 @@ def get_config(model: str,
raise e
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
return config