Use dtype from model config & Add Dolly V2 (#63)

This commit is contained in:
Woosuk Kwon 2023-05-04 03:05:37 -07:00 committed by GitHub
parent e548c1488a
commit 189ae23133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 7 deletions

View File

@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# NOTE(woosuk): FlashAttention does not support float32.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'))
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')

View File

@ -1,8 +1,9 @@
from typing import Union, Optional
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoConfig
from transformers import PretrainedConfig
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
@ -22,6 +23,7 @@ _MODELS = {
'opt': OPTForCausalLM,
'stablelm': GPTNeoXForCausalLM,
'pythia': GPTNeoXForCausalLM,
'dolly-v2': GPTNeoXForCausalLM,
}
_MEMORY_ANALYZERS = {
@ -30,19 +32,38 @@ _MEMORY_ANALYZERS = {
'opt': OPTMemoryAnalyzer,
'stablelm': GPTNeoXMemoryAnalyzer,
'pythia': GPTNeoXMemoryAnalyzer,
'dolly-v2': GPTNeoXMemoryAnalyzer,
}
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
if dtype == 'default':
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
torch_dtype = get_torch_dtype(dtype)
if torch_dtype != config_dtype and config_dtype != torch.float32:
# TODO(woosuk): Allow using float16 for bfloat16 models and
# vice versa. Print a warning message and continue.
raise ValueError(
f'Cannot use {torch_dtype} for {config_dtype} model.')
return torch_dtype
def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
dtype: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
) -> nn.Module:
torch_dtype = get_torch_dtype(dtype)
torch.set_default_dtype(torch_dtype)
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
torch.set_default_dtype(torch_dtype)
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
if use_dummy_weights:
@ -66,12 +87,13 @@ def get_model(
def get_memory_analyzer(
model_name: str,
block_size: int,
dtype: Union[torch.dtype, str],
dtype: str,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
torch_dtype = get_torch_dtype(dtype)
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
if model_class in model_name:
return memory_analyzer(