Add code-revision config argument for Hugging Face Hub (#2892)
This commit is contained in:
parent
8f36444c4f
commit
786b7f18a5
@ -44,6 +44,9 @@ class ModelConfig:
|
||||
revision: The specific model version to use. It can be a branch name,
|
||||
a tag name, or a commit id. If unspecified, will use the default
|
||||
version.
|
||||
code_revision: The specific revision to use for the model code on
|
||||
Hugging Face Hub. It can be a branch name, a tag name, or a
|
||||
commit id. If unspecified, will use the default version.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
branch name, a tag name, or a commit id. If unspecified, will use
|
||||
the default version.
|
||||
@ -70,6 +73,7 @@ class ModelConfig:
|
||||
dtype: Union[str, torch.dtype],
|
||||
seed: int,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
@ -84,6 +88,7 @@ class ModelConfig:
|
||||
self.load_format = load_format
|
||||
self.seed = seed
|
||||
self.revision = revision
|
||||
self.code_revision = code_revision
|
||||
self.tokenizer_revision = tokenizer_revision
|
||||
self.quantization = quantization
|
||||
self.enforce_eager = enforce_eager
|
||||
@ -103,7 +108,8 @@ class ModelConfig:
|
||||
self.download_dir = model_path
|
||||
self.tokenizer = model_path
|
||||
|
||||
self.hf_config = get_config(self.model, trust_remote_code, revision)
|
||||
self.hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self.max_model_len = _get_and_verify_max_len(self.hf_config,
|
||||
max_model_len)
|
||||
|
||||
@ -32,6 +32,7 @@ class EngineArgs:
|
||||
max_paddings: int = 256
|
||||
disable_log_stats: bool = False
|
||||
revision: Optional[str] = None
|
||||
code_revision: Optional[str] = None
|
||||
tokenizer_revision: Optional[str] = None
|
||||
quantization: Optional[str] = None
|
||||
enforce_eager: bool = False
|
||||
@ -75,6 +76,13 @@ class EngineArgs:
|
||||
help='the specific model version to use. It can be a branch '
|
||||
'name, a tag name, or a commit id. If unspecified, will use '
|
||||
'the default version.')
|
||||
parser.add_argument(
|
||||
'--code-revision',
|
||||
type=str,
|
||||
default=None,
|
||||
help='the specific revision to use for the model code on '
|
||||
'Hugging Face Hub. It can be a branch name, a tag name, or a '
|
||||
'commit id. If unspecified, will use the default version.')
|
||||
parser.add_argument(
|
||||
'--tokenizer-revision',
|
||||
type=str,
|
||||
@ -279,13 +287,12 @@ class EngineArgs:
|
||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
|
||||
DeviceConfig, Optional[LoRAConfig]]:
|
||||
device_config = DeviceConfig(self.device)
|
||||
model_config = ModelConfig(self.model, self.tokenizer,
|
||||
self.tokenizer_mode, self.trust_remote_code,
|
||||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed, self.revision,
|
||||
self.tokenizer_revision, self.max_model_len,
|
||||
self.quantization, self.enforce_eager,
|
||||
self.max_context_len_to_capture)
|
||||
model_config = ModelConfig(
|
||||
self.model, self.tokenizer, self.tokenizer_mode,
|
||||
self.trust_remote_code, self.download_dir, self.load_format,
|
||||
self.dtype, self.seed, self.revision, self.code_revision,
|
||||
self.tokenizer_revision, self.max_model_len, self.quantization,
|
||||
self.enforce_eager, self.max_context_len_to_capture)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space, self.kv_cache_dtype,
|
||||
|
||||
@ -16,10 +16,14 @@ _CONFIG_REGISTRY = {
|
||||
|
||||
def get_config(model: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None) -> PretrainedConfig:
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None) -> PretrainedConfig:
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision)
|
||||
model,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
code_revision=code_revision)
|
||||
except ValueError as e:
|
||||
if (not trust_remote_code and
|
||||
"requires you to execute the configuration file" in str(e)):
|
||||
@ -33,5 +37,7 @@ def get_config(model: str,
|
||||
raise e
|
||||
if config.model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||
config = config_class.from_pretrained(model, revision=revision)
|
||||
config = config_class.from_pretrained(model,
|
||||
revision=revision,
|
||||
code_revision=code_revision)
|
||||
return config
|
||||
|
||||
Loading…
Reference in New Issue
Block a user