Use dtype from model config & Add Dolly V2 (#63)

Woosuk Kwon 2023-05-04 03:05:37 -07:00 committed by GitHub
parent e548c1488a
commit 189ae23133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 7 deletions


@@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     # NOTE(woosuk): FlashAttention does not support float32.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
+    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+                        help=('data type for model weights and activations. '
+                              'The "default" option will use FP16 precision '
+                              'for FP32 and FP16 models, and BF16 precision '
+                              'for BF16 models.'))
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
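
For reference, the sketch below reproduces only the new --dtype argument in a throwaway parser so the accepted values can be checked in isolation; the parser and the example invocations are illustrative and not the server's actual entry point.

import argparse

# Standalone sketch of the new flag: 'default' defers the precision choice
# to the model config, while 'half' / 'bfloat16' force a dtype explicitly.
parser = argparse.ArgumentParser()
parser.add_argument('--dtype', type=str, default='default',
                    choices=['default', 'half', 'bfloat16'],
                    help='data type for model weights and activations')

print(parser.parse_args([]).dtype)                    # -> 'default'
print(parser.parse_args(['--dtype', 'half']).dtype)   # -> 'half'
# parser.parse_args(['--dtype', 'float32']) exits with an argparse error,
# since 'float32' is not among the allowed choices.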


@@ -1,8 +1,9 @@
-from typing import Union, Optional
+from typing import Optional

 import torch
 import torch.nn as nn
 from transformers import AutoConfig
+from transformers import PretrainedConfig

 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
@@ -22,6 +23,7 @@ _MODELS = {
     'opt': OPTForCausalLM,
     'stablelm': GPTNeoXForCausalLM,
     'pythia': GPTNeoXForCausalLM,
+    'dolly-v2': GPTNeoXForCausalLM,
 }

 _MEMORY_ANALYZERS = {
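
The registry keys are matched by substring against the requested model name (see the "model_class_name in model_name" check in get_model below), so the single 'dolly-v2' entry covers the whole Dolly V2 family. A tiny illustration of that lookup rule, using a simplified key list and the public databricks/dolly-v2-12b checkpoint name:

# Substring lookup, mirroring how get_model / get_memory_analyzer below
# resolve a model class; the plain list stands in for the real dicts.
_MODEL_KEYS = ['opt', 'stablelm', 'pythia', 'dolly-v2']

def find_model_key(model_name: str) -> str:
    for key in _MODEL_KEYS:
        if key in model_name:
            return key
    raise ValueError(f'Unsupported model: {model_name}')

print(find_model_key('databricks/dolly-v2-12b'))  # -> 'dolly-v2'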
@@ -30,19 +32,38 @@ _MEMORY_ANALYZERS = {
     'opt': OPTMemoryAnalyzer,
     'stablelm': GPTNeoXMemoryAnalyzer,
     'pythia': GPTNeoXMemoryAnalyzer,
+    'dolly-v2': GPTNeoXMemoryAnalyzer,
 }


+def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
+    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    if dtype == 'default':
+        if config_dtype == torch.float32:
+            # Following the common practice, we use float16 for float32 models.
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = config_dtype
+    else:
+        torch_dtype = get_torch_dtype(dtype)
+        if torch_dtype != config_dtype and config_dtype != torch.float32:
+            # TODO(woosuk): Allow using float16 for bfloat16 models and
+            # vice versa. Print a warning message and continue.
+            raise ValueError(
+                f'Cannot use {torch_dtype} for {config_dtype} model.')
+    return torch_dtype
+
+
 def get_model(
     model_name: str,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     cache_dir: Optional[str],
     use_dummy_weights: bool,
     use_np_cache: bool,
 ) -> nn.Module:
-    torch_dtype = get_torch_dtype(dtype)
-    torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
+    torch.set_default_dtype(torch_dtype)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             if use_dummy_weights:
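
To make the new selection rules concrete, here is a minimal, self-contained sketch that mirrors the logic of _get_dtype above without the cacheflow helpers; the _STR_TO_DTYPE mapping stands in for get_torch_dtype, which is not shown in this diff.

import torch

# Assumed stand-in for cacheflow's get_torch_dtype (not shown in this diff).
_STR_TO_DTYPE = {'half': torch.float16, 'bfloat16': torch.bfloat16}

def pick_dtype(config_dtype: torch.dtype, requested: str) -> torch.dtype:
    if requested == 'default':
        # FP32 checkpoints are downcast to FP16; FP16/BF16 checkpoints keep
        # the dtype recorded in their config.
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    torch_dtype = _STR_TO_DTYPE[requested]
    if torch_dtype != config_dtype and config_dtype != torch.float32:
        # Mixing, e.g., FP16 with a BF16 checkpoint is rejected, as above.
        raise ValueError(f'Cannot use {torch_dtype} for {config_dtype} model.')
    return torch_dtype

print(pick_dtype(torch.float32, 'default'))    # -> torch.float16
print(pick_dtype(torch.bfloat16, 'default'))   # -> torch.bfloat16
print(pick_dtype(torch.float32, 'bfloat16'))   # -> torch.bfloat16 (FP32 may be cast)
# pick_dtype(torch.bfloat16, 'half') raises ValueError.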
@@ -66,12 +87,13 @@ def get_model(
 def get_memory_analyzer(
     model_name: str,
     block_size: int,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     gpu_memory: int,
     cpu_memory: int,
     tensor_parallel_size: int = 1,
 ) -> CacheFlowMemoryAnalyzer:
-    torch_dtype = get_torch_dtype(dtype)
+    config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
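
Since both helpers now take the raw --dtype string, callers no longer convert to a torch.dtype themselves. A hypothetical wiring sketch, assuming the helpers live in cacheflow.models.model_utils and an OPT checkpoint is being served (neither detail is confirmed by this diff):

# Hypothetical usage; the module path, model name, and sizes are assumptions.
from cacheflow.models.model_utils import get_memory_analyzer, get_model

model = get_model(
    model_name='facebook/opt-125m',
    dtype='default',        # string straight from --dtype, no torch.dtype needed
    cache_dir=None,
    use_dummy_weights=False,
    use_np_cache=False,
)
analyzer = get_memory_analyzer(
    model_name='facebook/opt-125m',
    block_size=16,
    dtype='default',
    gpu_memory=16 * 1024 ** 3,
    cpu_memory=32 * 1024 ** 3,
)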