vllm/vllm/model_executor/models/__init__.py

104 lines
4.1 KiB
Python
Raw Normal View History

2023-12-13 14:21:45 +08:00
import importlib
from typing import List, Optional, Type
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip, is_neuron
# Module-level logger; ModelRegistry.load_model_cls warns through it when an
# architecture is only partially supported on ROCm.
logger = init_logger(__name__)
# Registry of supported architectures: architecture name -> (module, class).
# ``module`` is resolved relative to ``vllm.model_executor.models`` when the
# class is loaded, and several architecture names deliberately share one
# implementation class. Insertion order is significant: it is the order in
# which the names are reported as supported.
_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # Capitalisation used by decapoda-research/llama-*.
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # The transformers MPT class name uses lower case ("Mpt"), so both
    # spellings are registered against the same class.
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
}
# Architectures that are refused outright on ROCm (loading one raises
# ValueError). Currently none are listed.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
# Shared reason for every sliding-window model in the table below.
_ROCM_SLIDING_WINDOW_REASON = (
    "Sliding window attention is not yet supported in ROCm's flash attention")

# Architectures that load on ROCm with a warning instead of an error.
# Architecture -> human-readable reason appended to the warning.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "Qwen2ForCausalLM": _ROCM_SLIDING_WINDOW_REASON,
    "MistralForCausalLM": _ROCM_SLIDING_WINDOW_REASON,
    "MixtralForCausalLM": _ROCM_SLIDING_WINDOW_REASON,
}
# Models supported by Neuron (an allow-list: on Neuron, any architecture not
# listed here is rejected with ValueError).
# Architecture -> module, relative to ``vllm.model_executor.models``, that
# replaces the module named in ``_MODELS`` when running on Neuron; the class
# name still comes from ``_MODELS``.
_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
class ModelRegistry:
    """Resolve architecture names to model classes, importing them lazily."""

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Import and return the model class registered for ``model_arch``.

        The implementing module is imported on demand from
        ``vllm.model_executor.models.<module>``. On Neuron, the module listed
        in ``_NEURON_SUPPORTED_MODELS`` is used in place of the one given in
        ``_MODELS``.

        Args:
            model_arch: Architecture name to look up in ``_MODELS``.

        Returns:
            The registered class, or ``None`` when ``model_arch`` is not in
            ``_MODELS`` or the imported module has no attribute with the
            registered class name.

        Raises:
            ValueError: On ROCm, if ``model_arch`` is in
                ``_ROCM_UNSUPPORTED_MODELS``; on Neuron, if ``model_arch`` is
                absent from ``_NEURON_SUPPORTED_MODELS``.
        """
        # Unregistered architectures yield None rather than raising.
        if model_arch not in _MODELS:
            return None

        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(f"Model architecture {model_arch} is not "
                                 "supported by ROCm for now.")
            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                prefix = (f"Model architecture {model_arch} is partially "
                          "supported by ROCm: ")
                logger.warning(prefix +
                               _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
        elif is_neuron() and model_arch not in _NEURON_SUPPORTED_MODELS:
            raise ValueError(f"Model architecture {model_arch} is not "
                             "supported by Neuron for now.")

        module_name, class_name = _MODELS[model_arch]
        if is_neuron():
            # Keep the registered class name but load it from the
            # Neuron-specific module.
            module_name = _NEURON_SUPPORTED_MODELS[model_arch]

        package_path = "vllm.model_executor.models." + module_name
        module = importlib.import_module(package_path)
        return getattr(module, class_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return every registered architecture name, in ``_MODELS`` order.

        The result is not filtered by platform (ROCm or Neuron).
        """
        return list(_MODELS)
# Only the registry is exported; the lookup tables above stay private.
__all__ = ["ModelRegistry"]