Use dtype from model config & Add Dolly V2 (#63)
This commit is contained in:
parent
e548c1488a
commit
189ae23133
@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
|
||||
help='save a numpy copy of model weights for faster loading')
|
||||
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
|
||||
# NOTE(woosuk): FlashAttention does not support float32.
|
||||
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
|
||||
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
|
||||
help=('data type for model weights and activations. '
|
||||
'The "default" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.'))
|
||||
# Parallel arguments
|
||||
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
|
||||
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Union, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
|
||||
@ -22,6 +23,7 @@ _MODELS = {
|
||||
'opt': OPTForCausalLM,
|
||||
'stablelm': GPTNeoXForCausalLM,
|
||||
'pythia': GPTNeoXForCausalLM,
|
||||
'dolly-v2': GPTNeoXForCausalLM,
|
||||
}
|
||||
|
||||
_MEMORY_ANALYZERS = {
|
||||
@ -30,19 +32,38 @@ _MEMORY_ANALYZERS = {
|
||||
'opt': OPTMemoryAnalyzer,
|
||||
'stablelm': GPTNeoXMemoryAnalyzer,
|
||||
'pythia': GPTNeoXMemoryAnalyzer,
|
||||
'dolly-v2': GPTNeoXMemoryAnalyzer,
|
||||
}
|
||||
|
||||
|
||||
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
|
||||
config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
|
||||
if dtype == 'default':
|
||||
if config_dtype == torch.float32:
|
||||
# Following the common practice, we use float16 for float32 models.
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
torch_dtype = get_torch_dtype(dtype)
|
||||
if torch_dtype != config_dtype and config_dtype != torch.float32:
|
||||
# TODO(woosuk): Allow using float16 for bfloat16 models and
|
||||
# vice versa. Print a warning message and continue.
|
||||
raise ValueError(
|
||||
f'Cannot use {torch_dtype} for {config_dtype} model.')
|
||||
return torch_dtype
|
||||
|
||||
|
||||
def get_model(
|
||||
model_name: str,
|
||||
dtype: Union[torch.dtype, str],
|
||||
dtype: str,
|
||||
cache_dir: Optional[str],
|
||||
use_dummy_weights: bool,
|
||||
use_np_cache: bool,
|
||||
) -> nn.Module:
|
||||
torch_dtype = get_torch_dtype(dtype)
|
||||
torch.set_default_dtype(torch_dtype)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
torch_dtype = _get_dtype(config, dtype)
|
||||
torch.set_default_dtype(torch_dtype)
|
||||
for model_class_name, model_class in _MODELS.items():
|
||||
if model_class_name in model_name:
|
||||
if use_dummy_weights:
|
||||
@ -66,12 +87,13 @@ def get_model(
|
||||
def get_memory_analyzer(
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
dtype: Union[torch.dtype, str],
|
||||
dtype: str,
|
||||
gpu_memory: int,
|
||||
cpu_memory: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
) -> CacheFlowMemoryAnalyzer:
|
||||
torch_dtype = get_torch_dtype(dtype)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
torch_dtype = _get_dtype(config, dtype)
|
||||
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
|
||||
if model_class in model_name:
|
||||
return memory_analyzer(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user