# vllm/vllm/model_executor/models/__init__.py
import functools
import importlib
from typing import Dict, List, Optional, Tuple, Type

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)
# Architecture -> (module, class).
_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration":
("llava_next", "LlavaNextForConditionalGeneration"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers' MPT class name is lowercased ("Mpt"), so accept both.
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
}
_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
_MODELS = {
**_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_CONDITIONAL_GENERATION_MODELS
}
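
# Illustrative lookup (a sketch, not executed at import time): resolving an
# architecture name goes through this table first, e.g.
#
#   module_name, cls_name = _MODELS["LlamaForCausalLM"]
#   # -> ("llama", "LlamaForCausalLM"), i.e. the LlamaForCausalLM class in
#   # vllm.model_executor.models.llama, imported lazily by ModelRegistry.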
# Architecture -> type.
# Out-of-tree models, registered at runtime via ModelRegistry.register_model().
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM":
_ROCM_SWA_REASON,
"MistralForCausalLM":
_ROCM_SWA_REASON,
"MixtralForCausalLM":
_ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
class ModelRegistry:
@staticmethod
@functools.lru_cache(maxsize=128)
def _get_model(model_arch: str):
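        """Lazily import and return the model class for ``model_arch``.

        ``lru_cache`` memoizes the import, so repeated lookups of the same
        architecture skip ``importlib``; returns None when the module does
        not define the expected class name.
        """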
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
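        """Resolve ``model_arch`` to a model class, or None if unknown.

        Out-of-tree models registered via ``register_model`` take precedence
        over the built-in ``_MODELS`` table; on ROCm, unsupported
        architectures raise and partially supported ones log a warning.
        """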
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
return None
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
return ModelRegistry._get_model(model_arch)
@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
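        """Return the first resolvable (model class, architecture) pair.

        A minimal usage sketch (the architecture name here matches an entry
        in ``_MODELS``; any registered name works the same way)::

            model_cls, arch = ModelRegistry.resolve_model_cls(
                ["LlamaForCausalLM"])
            # arch == "LlamaForCausalLM"; model_cls is the class from
            # vllm.model_executor.models.llama
        """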
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
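        """Register an out-of-tree model class for ``model_arch``.

        A minimal sketch of the plugin pattern this enables
        (``MyOPTForCausalLM`` is a hypothetical user-defined ``nn.Module``
        subclass, not part of this file)::

            from vllm import ModelRegistry
            ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
        """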
if model_arch in _MODELS:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
@staticmethod
def is_embedding_model(model_arch: str) -> bool:
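        """Return True if ``model_arch`` is an embedding-only architecture."""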
return model_arch in _EMBEDDING_MODELS
__all__ = [
"ModelRegistry",
]