Add code-revision config argument for Hugging Face Hub (#2892)
This commit is contained in:
parent
8f36444c4f
commit
786b7f18a5
@ -44,6 +44,9 @@ class ModelConfig:
|
|||||||
revision: The specific model version to use. It can be a branch name,
|
revision: The specific model version to use. It can be a branch name,
|
||||||
a tag name, or a commit id. If unspecified, will use the default
|
a tag name, or a commit id. If unspecified, will use the default
|
||||||
version.
|
version.
|
||||||
|
code_revision: The specific revision to use for the model code on
|
||||||
|
Hugging Face Hub. It can be a branch name, a tag name, or a
|
||||||
|
commit id. If unspecified, will use the default version.
|
||||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||||
branch name, a tag name, or a commit id. If unspecified, will use
|
branch name, a tag name, or a commit id. If unspecified, will use
|
||||||
the default version.
|
the default version.
|
||||||
@ -70,6 +73,7 @@ class ModelConfig:
|
|||||||
dtype: Union[str, torch.dtype],
|
dtype: Union[str, torch.dtype],
|
||||||
seed: int,
|
seed: int,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
|
code_revision: Optional[str] = None,
|
||||||
tokenizer_revision: Optional[str] = None,
|
tokenizer_revision: Optional[str] = None,
|
||||||
max_model_len: Optional[int] = None,
|
max_model_len: Optional[int] = None,
|
||||||
quantization: Optional[str] = None,
|
quantization: Optional[str] = None,
|
||||||
@ -84,6 +88,7 @@ class ModelConfig:
|
|||||||
self.load_format = load_format
|
self.load_format = load_format
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
|
self.code_revision = code_revision
|
||||||
self.tokenizer_revision = tokenizer_revision
|
self.tokenizer_revision = tokenizer_revision
|
||||||
self.quantization = quantization
|
self.quantization = quantization
|
||||||
self.enforce_eager = enforce_eager
|
self.enforce_eager = enforce_eager
|
||||||
@ -103,7 +108,8 @@ class ModelConfig:
|
|||||||
self.download_dir = model_path
|
self.download_dir = model_path
|
||||||
self.tokenizer = model_path
|
self.tokenizer = model_path
|
||||||
|
|
||||||
self.hf_config = get_config(self.model, trust_remote_code, revision)
|
self.hf_config = get_config(self.model, trust_remote_code, revision,
|
||||||
|
code_revision)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||||
self.max_model_len = _get_and_verify_max_len(self.hf_config,
|
self.max_model_len = _get_and_verify_max_len(self.hf_config,
|
||||||
max_model_len)
|
max_model_len)
|
||||||
|
|||||||
@ -32,6 +32,7 @@ class EngineArgs:
|
|||||||
max_paddings: int = 256
|
max_paddings: int = 256
|
||||||
disable_log_stats: bool = False
|
disable_log_stats: bool = False
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
code_revision: Optional[str] = None
|
||||||
tokenizer_revision: Optional[str] = None
|
tokenizer_revision: Optional[str] = None
|
||||||
quantization: Optional[str] = None
|
quantization: Optional[str] = None
|
||||||
enforce_eager: bool = False
|
enforce_eager: bool = False
|
||||||
@ -75,6 +76,13 @@ class EngineArgs:
|
|||||||
help='the specific model version to use. It can be a branch '
|
help='the specific model version to use. It can be a branch '
|
||||||
'name, a tag name, or a commit id. If unspecified, will use '
|
'name, a tag name, or a commit id. If unspecified, will use '
|
||||||
'the default version.')
|
'the default version.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--code-revision',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='the specific revision to use for the model code on '
|
||||||
|
'Hugging Face Hub. It can be a branch name, a tag name, or a '
|
||||||
|
'commit id. If unspecified, will use the default version.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--tokenizer-revision',
|
'--tokenizer-revision',
|
||||||
type=str,
|
type=str,
|
||||||
@ -279,13 +287,12 @@ class EngineArgs:
|
|||||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
|
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
|
||||||
DeviceConfig, Optional[LoRAConfig]]:
|
DeviceConfig, Optional[LoRAConfig]]:
|
||||||
device_config = DeviceConfig(self.device)
|
device_config = DeviceConfig(self.device)
|
||||||
model_config = ModelConfig(self.model, self.tokenizer,
|
model_config = ModelConfig(
|
||||||
self.tokenizer_mode, self.trust_remote_code,
|
self.model, self.tokenizer, self.tokenizer_mode,
|
||||||
self.download_dir, self.load_format,
|
self.trust_remote_code, self.download_dir, self.load_format,
|
||||||
self.dtype, self.seed, self.revision,
|
self.dtype, self.seed, self.revision, self.code_revision,
|
||||||
self.tokenizer_revision, self.max_model_len,
|
self.tokenizer_revision, self.max_model_len, self.quantization,
|
||||||
self.quantization, self.enforce_eager,
|
self.enforce_eager, self.max_context_len_to_capture)
|
||||||
self.max_context_len_to_capture)
|
|
||||||
cache_config = CacheConfig(self.block_size,
|
cache_config = CacheConfig(self.block_size,
|
||||||
self.gpu_memory_utilization,
|
self.gpu_memory_utilization,
|
||||||
self.swap_space, self.kv_cache_dtype,
|
self.swap_space, self.kv_cache_dtype,
|
||||||
|
|||||||
@ -16,10 +16,14 @@ _CONFIG_REGISTRY = {
|
|||||||
|
|
||||||
def get_config(model: str,
|
def get_config(model: str,
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
revision: Optional[str] = None) -> PretrainedConfig:
|
revision: Optional[str] = None,
|
||||||
|
code_revision: Optional[str] = None) -> PretrainedConfig:
|
||||||
try:
|
try:
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model, trust_remote_code=trust_remote_code, revision=revision)
|
model,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
revision=revision,
|
||||||
|
code_revision=code_revision)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if (not trust_remote_code and
|
if (not trust_remote_code and
|
||||||
"requires you to execute the configuration file" in str(e)):
|
"requires you to execute the configuration file" in str(e)):
|
||||||
@ -33,5 +37,7 @@ def get_config(model: str,
|
|||||||
raise e
|
raise e
|
||||||
if config.model_type in _CONFIG_REGISTRY:
|
if config.model_type in _CONFIG_REGISTRY:
|
||||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||||
config = config_class.from_pretrained(model, revision=revision)
|
config = config_class.from_pretrained(model,
|
||||||
|
revision=revision,
|
||||||
|
code_revision=code_revision)
|
||||||
return config
|
return config
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user