# vllm/cacheflow/models/model_utils.py

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoConfig

from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.gpt2 import GPT2LMHeadModel
from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype

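# Maps a keyword to its model class; a model name matches an entry when the
# keyword appears as a substring of the name.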
_MODELS = {
    'gpt2': GPT2LMHeadModel,
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
    'stablelm': GPTNeoXForCausalLM,
    'pythia': GPTNeoXForCausalLM,
}
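
# Maps the same keywords to their memory analyzers.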
_MEMORY_ANALYZERS = {
    'gpt2': GPT2MemoryAnalyzer,
    'llama': LlamaMemoryAnalyzer,
    'opt': OPTMemoryAnalyzer,
    'stablelm': GPTNeoXMemoryAnalyzer,
    'pythia': GPTNeoXMemoryAnalyzer,
}


def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
    cache_dir: Optional[str],
    use_dummy_weights: bool,
    use_np_cache: bool,
) -> Tuple[nn.Module, torch.dtype]:
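    """Looks up the model class by keyword match against `model_name` and
    returns the model on GPU in eval mode, together with its torch dtype."""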
    torch_dtype = get_torch_dtype(dtype)
    torch.set_default_dtype(torch_dtype)
    config = AutoConfig.from_pretrained(model_name)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            if use_dummy_weights:
                # Create a model instance.
                # The weights will be initialized as empty tensors.
                model = model_class(config)
                model = model.cuda()
                # NOTE(woosuk): For precise performance evaluation, we assign
                # random values to the weights.
                model.initialize_dummy_weights()
            else:
                # Create a model instance.
                model = model_class(config)
                # Load the weights from the cached or downloaded files.
                model.load_weights(model_name, cache_dir, use_np_cache)
                model = model.cuda()
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')


def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
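    """Looks up the memory analyzer by keyword match against `model_name`
    and constructs it with the given block size and memory budgets."""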
    torch_dtype = get_torch_dtype(dtype)
    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
        if model_class in model_name:
            return memory_analyzer(
                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
                tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')
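

# A minimal usage sketch, not part of the module: the model name
# 'facebook/opt-125m' is a hypothetical choice for illustration, and the
# substring matching above routes it to the 'opt' entry. It assumes a CUDA
# device is available and that get_torch_dtype accepts torch.float16.
#
#   model, torch_dtype = get_model(
#       model_name='facebook/opt-125m',
#       dtype=torch.float16,
#       cache_dir=None,
#       use_dummy_weights=True,  # random weights, so no download is needed
#       use_np_cache=False,
#   )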