feat(config): support parsing torch.dtype (#1641)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
parent
b514d3c496
commit
65ea2ddf17
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
@ -58,7 +58,7 @@ class ModelConfig:
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
download_dir: Optional[str],
|
download_dir: Optional[str],
|
||||||
load_format: str,
|
load_format: str,
|
||||||
dtype: str,
|
dtype: Union[str, torch.dtype],
|
||||||
seed: int,
|
seed: int,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
tokenizer_revision: Optional[str] = None,
|
tokenizer_revision: Optional[str] = None,
|
||||||
@ -331,7 +331,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
|
|||||||
|
|
||||||
def _get_and_verify_dtype(
|
def _get_and_verify_dtype(
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
dtype: str,
|
dtype: Union[str, torch.dtype],
|
||||||
) -> torch.dtype:
|
) -> torch.dtype:
|
||||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||||
# because config.torch_dtype can be None.
|
# because config.torch_dtype can be None.
|
||||||
@ -339,17 +339,23 @@ def _get_and_verify_dtype(
|
|||||||
if config_dtype is None:
|
if config_dtype is None:
|
||||||
config_dtype = torch.float32
|
config_dtype = torch.float32
|
||||||
|
|
||||||
dtype = dtype.lower()
|
if isinstance(dtype, str):
|
||||||
if dtype == "auto":
|
dtype = dtype.lower()
|
||||||
if config_dtype == torch.float32:
|
if dtype == "auto":
|
||||||
# Following the common practice, we use float16 for float32 models.
|
if config_dtype == torch.float32:
|
||||||
torch_dtype = torch.float16
|
# Following the common practice, we use float16 for float32
|
||||||
|
# models.
|
||||||
|
torch_dtype = torch.float16
|
||||||
|
else:
|
||||||
|
torch_dtype = config_dtype
|
||||||
else:
|
else:
|
||||||
torch_dtype = config_dtype
|
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||||
|
raise ValueError(f"Unknown dtype: {dtype}")
|
||||||
|
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||||
|
elif isinstance(dtype, torch.dtype):
|
||||||
|
torch_dtype = dtype
|
||||||
else:
|
else:
|
||||||
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
raise ValueError(f"Unknown dtype: {dtype}")
|
||||||
raise ValueError(f"Unknown dtype: {dtype}")
|
|
||||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
|
||||||
|
|
||||||
# Verify the dtype.
|
# Verify the dtype.
|
||||||
if torch_dtype != config_dtype:
|
if torch_dtype != config_dtype:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user