From c84e924287fbaf923994865806e7ebc93b4070e6 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 6 May 2023 02:12:12 -0700
Subject: [PATCH] [Minor] Fix a dtype bug (#79)

---
 cacheflow/models/model_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py
index 91ca4f32..ca838d5e 100644
--- a/cacheflow/models/model_utils.py
+++ b/cacheflow/models/model_utils.py
@@ -37,7 +37,11 @@ _MEMORY_ANALYZERS = {
 
 
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, 'torch_dtype', None)
+    if config_dtype is None:
+        config_dtype = torch.float32
     if dtype == 'default':
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.