feat(config): support parsing torch.dtype (#1641)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron Pham
Date: 2023-11-16 04:31:06 -05:00
Parent: b514d3c496
Commit: 65ea2ddf17


@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -58,7 +58,7 @@ class ModelConfig:
         trust_remote_code: bool,
         download_dir: Optional[str],
         load_format: str,
-        dtype: str,
+        dtype: Union[str, torch.dtype],
         seed: int,
         revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
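
With this change, callers can hand ModelConfig a torch.dtype directly instead of a string. A minimal sketch of the two now-equivalent call styles, assuming vLLM's public LLM entry point forwards its dtype argument down to ModelConfig unchanged:

```python
import torch
from vllm import LLM

# Both forms should now resolve to torch.float16; previously only the
# string spelling was accepted.
llm_from_str = LLM(model="facebook/opt-125m", dtype="float16")
llm_from_dtype = LLM(model="facebook/opt-125m", dtype=torch.float16)
```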
@@ -331,7 +331,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
 
 def _get_and_verify_dtype(
     config: PretrainedConfig,
-    dtype: str,
+    dtype: Union[str, torch.dtype],
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
@@ -339,17 +339,23 @@ def _get_and_verify_dtype(
     if config_dtype is None:
         config_dtype = torch.float32
 
-    dtype = dtype.lower()
-    if dtype == "auto":
-        if config_dtype == torch.float32:
-            # Following the common practice, we use float16 for float32 models.
-            torch_dtype = torch.float16
-        else:
-            torch_dtype = config_dtype
-    else:
-        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
-            raise ValueError(f"Unknown dtype: {dtype}")
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    if isinstance(dtype, str):
+        dtype = dtype.lower()
+        if dtype == "auto":
+            if config_dtype == torch.float32:
+                # Following the common practice, we use float16 for float32
+                # models.
+                torch_dtype = torch.float16
+            else:
+                torch_dtype = config_dtype
+        else:
+            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                raise ValueError(f"Unknown dtype: {dtype}")
+            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    elif isinstance(dtype, torch.dtype):
+        torch_dtype = dtype
+    else:
+        raise ValueError(f"Unknown dtype: {dtype}")
 
     # Verify the dtype.
     if torch_dtype != config_dtype:
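
The dispatch above is the core of the change: strings keep the existing lowercase-and-lookup path (with "auto" downcasting float32 checkpoints to float16), a torch.dtype passes through untouched, and anything else raises. A small sketch of the expected behavior, assuming the helper lives at vllm.config._get_and_verify_dtype and that PretrainedConfig's torch_dtype defaults to None (which the function treats as float32):

```python
import torch
from transformers import PretrainedConfig
from vllm.config import _get_and_verify_dtype  # private helper patched above

config = PretrainedConfig()  # torch_dtype is None -> treated as float32

# String and torch.dtype spellings now resolve identically.
assert _get_and_verify_dtype(config, "float16") == torch.float16
assert _get_and_verify_dtype(config, torch.float16) == torch.float16

# "auto" on a float32 model still downcasts to float16.
assert _get_and_verify_dtype(config, "auto") == torch.float16
```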