diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py
index e2a9956f..6f28b96f 100644
--- a/cacheflow/master/server.py
+++ b/cacheflow/master/server.py
@@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     # NOTE(woosuk): FlashAttention does not support float32.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
+    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+                        help=('data type for model weights and activations. '
+                              'The "default" option will use FP16 precision '
+                              'for FP32 and FP16 models, and BF16 precision '
+                              'for BF16 models.'))
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py
index 67fdd0f2..91ca4f32 100644
--- a/cacheflow/models/model_utils.py
+++ b/cacheflow/models/model_utils.py
@@ -1,8 +1,9 @@
-from typing import Union, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn
 from transformers import AutoConfig
+from transformers import PretrainedConfig
 
 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
@@ -22,6 +23,7 @@ _MODELS = {
     'opt': OPTForCausalLM,
     'stablelm': GPTNeoXForCausalLM,
     'pythia': GPTNeoXForCausalLM,
+    'dolly-v2': GPTNeoXForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
@@ -30,19 +32,38 @@
     'opt': OPTMemoryAnalyzer,
     'stablelm': GPTNeoXMemoryAnalyzer,
     'pythia': GPTNeoXMemoryAnalyzer,
+    'dolly-v2': GPTNeoXMemoryAnalyzer,
 }
 
 
+def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
+    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    if dtype == 'default':
+        if config_dtype == torch.float32:
+            # Following the common practice, we use float16 for float32 models.
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = config_dtype
+    else:
+        torch_dtype = get_torch_dtype(dtype)
+        if torch_dtype != config_dtype and config_dtype != torch.float32:
+            # TODO(woosuk): Allow using float16 for bfloat16 models and
+            # vice versa. Print a warning message and continue.
+            raise ValueError(
+                f'Cannot use {torch_dtype} for {config_dtype} model.')
+    return torch_dtype
+
+
 def get_model(
     model_name: str,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     cache_dir: Optional[str],
     use_dummy_weights: bool,
     use_np_cache: bool,
 ) -> nn.Module:
-    torch_dtype = get_torch_dtype(dtype)
-    torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
+    torch.set_default_dtype(torch_dtype)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             if use_dummy_weights:
@@ -66,12 +87,13 @@ def get_model(
 def get_memory_analyzer(
     model_name: str,
     block_size: int,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     gpu_memory: int,
     cpu_memory: int,
     tensor_parallel_size: int = 1,
 ) -> CacheFlowMemoryAnalyzer:
-    torch_dtype = get_torch_dtype(dtype)
+    config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
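For reference, a minimal sketch of how the new '--dtype' choices resolve to torch dtypes, mirroring _get_dtype above; the standalone helper name resolve_dtype is illustrative and not part of the diff:

import torch

# 'default' keeps the checkpoint dtype, except FP32 checkpoints are downcast to
# FP16; an explicit 'half'/'bfloat16' request is rejected when it conflicts with
# a non-FP32 checkpoint dtype (see the TODO in the diff).
def resolve_dtype(config_dtype: torch.dtype, dtype_arg: str = 'default') -> torch.dtype:
    if dtype_arg == 'default':
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    requested = {'half': torch.float16, 'bfloat16': torch.bfloat16}[dtype_arg]
    if requested != config_dtype and config_dtype != torch.float32:
        raise ValueError(f'Cannot use {requested} for {config_dtype} model.')
    return requested

assert resolve_dtype(torch.float32) == torch.float16                # FP32 model served in FP16
assert resolve_dtype(torch.bfloat16) == torch.bfloat16              # BF16 model keeps BF16
assert resolve_dtype(torch.float32, 'bfloat16') == torch.bfloat16   # explicit override of an FP32 model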