From 65ea2ddf172a7234017f11d161ce87141deff3a2 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Thu, 16 Nov 2023 04:31:06 -0500 Subject: [PATCH] feat(config): support parsing torch.dtype (#1641) Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- vllm/config.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d91e18c2..fda9a268 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch from transformers import PretrainedConfig @@ -58,7 +58,7 @@ class ModelConfig: trust_remote_code: bool, download_dir: Optional[str], load_format: str, - dtype: str, + dtype: Union[str, torch.dtype], seed: int, revision: Optional[str] = None, tokenizer_revision: Optional[str] = None, @@ -331,7 +331,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = { def _get_and_verify_dtype( config: PretrainedConfig, - dtype: str, + dtype: Union[str, torch.dtype], ) -> torch.dtype: # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. @@ -339,17 +339,23 @@ def _get_and_verify_dtype( if config_dtype is None: config_dtype = torch.float32 - dtype = dtype.lower() - if dtype == "auto": - if config_dtype == torch.float32: - # Following the common practice, we use float16 for float32 models. - torch_dtype = torch.float16 + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + # Following the common practice, we use float16 for float32 + # models. + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype else: - torch_dtype = config_dtype + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype else: - if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: - raise ValueError(f"Unknown dtype: {dtype}") - torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + raise ValueError(f"Unknown dtype: {dtype}") # Verify the dtype. if torch_dtype != config_dtype: