[Misc] Move registry to its own file (#9064)
This commit is contained in:
parent
0f6d7a9a34
commit
0e36fd4909
@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
|
|||||||
5. Register your model
|
5. Register your model
|
||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_.
|
Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
|
||||||
|
|
||||||
6. Out-of-Tree Model Integration
|
6. Out-of-Tree Model Integration
|
||||||
--------------------------------------------
|
--------------------------------------------
|
||||||
|
|||||||
@ -3,13 +3,13 @@ import warnings
|
|||||||
import pytest
|
import pytest
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
|
||||||
from vllm.model_executor.models import _MODELS, ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import fork_new_process_for_each_test
|
from ..utils import fork_new_process_for_each_test
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_arch", _MODELS)
|
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
|
||||||
def test_registry_imports(model_arch):
|
def test_registry_imports(model_arch):
|
||||||
# Ensure all model classes can be imported successfully
|
# Ensure all model classes can be imported successfully
|
||||||
ModelRegistry.resolve_model_cls(model_arch)
|
ModelRegistry.resolve_model_cls(model_arch)
|
||||||
|
|||||||
@ -24,8 +24,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
|||||||
from vllm.lora.punica import PunicaWrapper
|
from vllm.lora.punica import PunicaWrapper
|
||||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||||
parse_fine_tuned_lora_name, replace_submodule)
|
parse_fine_tuned_lora_name, replace_submodule)
|
||||||
from vllm.model_executor.models.interfaces import (SupportsLoRA,
|
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
||||||
supports_multimodal)
|
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.models.utils import PPMissingLayer
|
from vllm.model_executor.models.utils import PPMissingLayer
|
||||||
from vllm.utils import is_pin_memory_available
|
from vllm.utils import is_pin_memory_available
|
||||||
|
|||||||
@ -41,8 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
|
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
|
||||||
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
||||||
safetensors_weights_iterator)
|
safetensors_weights_iterator)
|
||||||
from vllm.model_executor.models.interfaces import (has_inner_state,
|
from vllm.model_executor.models import (has_inner_state, supports_lora,
|
||||||
supports_lora,
|
|
||||||
supports_multimodal)
|
supports_multimodal)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|||||||
@ -1,325 +1,16 @@
|
|||||||
import importlib
|
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
|
||||||
import string
|
SupportsPP, has_inner_state, supports_lora,
|
||||||
import subprocess
|
supports_multimodal, supports_pp)
|
||||||
import sys
|
from .registry import ModelRegistry
|
||||||
import uuid
|
|
||||||
from functools import lru_cache, partial
|
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.utils import is_hip
|
|
||||||
|
|
||||||
from .interfaces import supports_multimodal, supports_pp
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
_GENERATION_MODELS = {
|
|
||||||
"AquilaModel": ("llama", "LlamaForCausalLM"),
|
|
||||||
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
|
|
||||||
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
|
||||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
|
|
||||||
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
|
|
||||||
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
|
||||||
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
|
||||||
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
|
|
||||||
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
|
|
||||||
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
|
|
||||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
|
||||||
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
|
|
||||||
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
|
|
||||||
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
|
|
||||||
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
|
||||||
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
|
||||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
|
||||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
|
||||||
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
|
||||||
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
|
||||||
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
|
|
||||||
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
|
|
||||||
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
|
|
||||||
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
|
|
||||||
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
|
|
||||||
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
|
||||||
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
|
||||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
|
||||||
# For decapoda-research/llama-*
|
|
||||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
|
||||||
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
|
||||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
|
||||||
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
|
|
||||||
# transformers's mpt class has lower case
|
|
||||||
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
|
||||||
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
|
|
||||||
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
|
|
||||||
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
|
|
||||||
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
|
|
||||||
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
|
||||||
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
|
|
||||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
|
||||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
|
||||||
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
|
||||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
|
||||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
|
||||||
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
|
||||||
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
|
||||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
|
||||||
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
|
|
||||||
"Qwen2VLForConditionalGeneration":
|
|
||||||
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
|
|
||||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
|
||||||
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
|
||||||
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
|
||||||
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
|
||||||
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
|
||||||
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
|
||||||
# NOTE: The below models are for speculative decoding only
|
|
||||||
"MedusaModel": ("medusa", "Medusa"),
|
|
||||||
"EAGLEModel": ("eagle", "EAGLE"),
|
|
||||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
|
||||||
}
|
|
||||||
|
|
||||||
_EMBEDDING_MODELS = {
|
|
||||||
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
|
|
||||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
|
||||||
}
|
|
||||||
|
|
||||||
_MULTIMODAL_MODELS = {
|
|
||||||
"Blip2ForConditionalGeneration":
|
|
||||||
("blip2", "Blip2ForConditionalGeneration"),
|
|
||||||
"ChameleonForConditionalGeneration":
|
|
||||||
("chameleon", "ChameleonForConditionalGeneration"),
|
|
||||||
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
|
||||||
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
|
||||||
"LlavaForConditionalGeneration": ("llava",
|
|
||||||
"LlavaForConditionalGeneration"),
|
|
||||||
"LlavaNextForConditionalGeneration": ("llava_next",
|
|
||||||
"LlavaNextForConditionalGeneration"),
|
|
||||||
"LlavaNextVideoForConditionalGeneration":
|
|
||||||
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
|
|
||||||
"LlavaOnevisionForConditionalGeneration":
|
|
||||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
|
||||||
"MiniCPMV": ("minicpmv", "MiniCPMV"),
|
|
||||||
"PaliGemmaForConditionalGeneration": ("paligemma",
|
|
||||||
"PaliGemmaForConditionalGeneration"),
|
|
||||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
|
||||||
"PixtralForConditionalGeneration": ("pixtral",
|
|
||||||
"PixtralForConditionalGeneration"),
|
|
||||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
|
||||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
|
|
||||||
"Qwen2VLForConditionalGeneration"),
|
|
||||||
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
|
||||||
"MllamaForConditionalGeneration": ("mllama",
|
|
||||||
"MllamaForConditionalGeneration"),
|
|
||||||
}
|
|
||||||
_CONDITIONAL_GENERATION_MODELS = {
|
|
||||||
"BartModel": ("bart", "BartForConditionalGeneration"),
|
|
||||||
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
|
|
||||||
}
|
|
||||||
|
|
||||||
_MODELS = {
|
|
||||||
**_GENERATION_MODELS,
|
|
||||||
**_EMBEDDING_MODELS,
|
|
||||||
**_MULTIMODAL_MODELS,
|
|
||||||
**_CONDITIONAL_GENERATION_MODELS,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Architecture -> type.
|
|
||||||
# out of tree models
|
|
||||||
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
|
|
||||||
|
|
||||||
# Models not supported by ROCm.
|
|
||||||
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
|
||||||
|
|
||||||
# Models partially supported by ROCm.
|
|
||||||
# Architecture -> Reason.
|
|
||||||
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
|
|
||||||
"Triton flash attention. For half-precision SWA support, "
|
|
||||||
"please use CK flash attention by setting "
|
|
||||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
|
|
||||||
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
|
||||||
"Qwen2ForCausalLM":
|
|
||||||
_ROCM_SWA_REASON,
|
|
||||||
"MistralForCausalLM":
|
|
||||||
_ROCM_SWA_REASON,
|
|
||||||
"MixtralForCausalLM":
|
|
||||||
_ROCM_SWA_REASON,
|
|
||||||
"PaliGemmaForConditionalGeneration":
|
|
||||||
("ROCm flash attention does not yet "
|
|
||||||
"fully support 32-bit precision on PaliGemma"),
|
|
||||||
"Phi3VForCausalLM":
|
|
||||||
("ROCm Triton flash attention may run into compilation errors due to "
|
|
||||||
"excessive use of shared memory. If this happens, disable Triton FA "
|
|
||||||
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRegistry:
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
|
|
||||||
module_relname, cls_name = _MODELS[model_arch]
|
|
||||||
return f"vllm.model_executor.models.{module_relname}", cls_name
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@lru_cache(maxsize=128)
|
|
||||||
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
|
|
||||||
if model_arch not in _MODELS:
|
|
||||||
return None
|
|
||||||
|
|
||||||
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
return getattr(module, cls_name, None)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
|
|
||||||
if model_arch in _OOT_MODELS:
|
|
||||||
return _OOT_MODELS[model_arch]
|
|
||||||
|
|
||||||
if is_hip():
|
|
||||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model architecture {model_arch} is not supported by "
|
|
||||||
"ROCm for now.")
|
|
||||||
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
|
||||||
logger.warning(
|
|
||||||
"Model architecture %s is partially supported by ROCm: %s",
|
|
||||||
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
|
||||||
model = ModelRegistry._try_get_model_stateless(model_arch)
|
|
||||||
if model is not None:
|
|
||||||
return model
|
|
||||||
|
|
||||||
return ModelRegistry._try_get_model_stateful(model_arch)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def resolve_model_cls(
|
|
||||||
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
for arch in architectures:
|
|
||||||
model_cls = ModelRegistry._try_load_model_cls(arch)
|
|
||||||
if model_cls is not None:
|
|
||||||
return (model_cls, arch)
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Model architectures {architectures} are not supported for now. "
|
|
||||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_supported_archs() -> List[str]:
|
|
||||||
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def register_model(model_arch: str, model_cls: Type[nn.Module]):
|
|
||||||
if model_arch in _MODELS:
|
|
||||||
logger.warning(
|
|
||||||
"Model architecture %s is already registered, and will be "
|
|
||||||
"overwritten by the new model class %s.", model_arch,
|
|
||||||
model_cls.__name__)
|
|
||||||
|
|
||||||
_OOT_MODELS[model_arch] = model_cls
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@lru_cache(maxsize=128)
|
|
||||||
def _check_stateless(
|
|
||||||
func: Callable[[Type[nn.Module]], bool],
|
|
||||||
model_arch: str,
|
|
||||||
*,
|
|
||||||
default: Optional[bool] = None,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Run a boolean function against a model and return the result.
|
|
||||||
|
|
||||||
If the model is not found, returns the provided default value.
|
|
||||||
|
|
||||||
If the model is not already imported, the function is run inside a
|
|
||||||
subprocess to avoid initializing CUDA for the main program.
|
|
||||||
"""
|
|
||||||
model = ModelRegistry._try_get_model_stateless(model_arch)
|
|
||||||
if model is not None:
|
|
||||||
return func(model)
|
|
||||||
|
|
||||||
if model_arch not in _MODELS and default is not None:
|
|
||||||
return default
|
|
||||||
|
|
||||||
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
|
||||||
|
|
||||||
valid_name_characters = string.ascii_letters + string.digits + "._"
|
|
||||||
if any(s not in valid_name_characters for s in module_name):
|
|
||||||
raise ValueError(f"Unsafe module name detected for {model_arch}")
|
|
||||||
if any(s not in valid_name_characters for s in cls_name):
|
|
||||||
raise ValueError(f"Unsafe class name detected for {model_arch}")
|
|
||||||
if any(s not in valid_name_characters for s in func.__module__):
|
|
||||||
raise ValueError(f"Unsafe module name detected for {func}")
|
|
||||||
if any(s not in valid_name_characters for s in func.__name__):
|
|
||||||
raise ValueError(f"Unsafe class name detected for {func}")
|
|
||||||
|
|
||||||
err_id = uuid.uuid4()
|
|
||||||
|
|
||||||
stmts = ";".join([
|
|
||||||
f"from {module_name} import {cls_name}",
|
|
||||||
f"from {func.__module__} import {func.__name__}",
|
|
||||||
f"assert {func.__name__}({cls_name}), '{err_id}'",
|
|
||||||
])
|
|
||||||
|
|
||||||
result = subprocess.run([sys.executable, "-c", stmts],
|
|
||||||
capture_output=True)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
err_lines = [line.decode() for line in result.stderr.splitlines()]
|
|
||||||
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
|
|
||||||
err_str = "\n".join(err_lines)
|
|
||||||
raise RuntimeError(
|
|
||||||
"An unexpected error occurred while importing the model in "
|
|
||||||
f"another process. Error log:\n{err_str}")
|
|
||||||
|
|
||||||
return result.returncode == 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
return any(arch in _EMBEDDING_MODELS for arch in architectures)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
is_mm = partial(ModelRegistry._check_stateless,
|
|
||||||
supports_multimodal,
|
|
||||||
default=False)
|
|
||||||
|
|
||||||
return any(is_mm(arch) for arch in architectures)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
is_pp = partial(ModelRegistry._check_stateless,
|
|
||||||
supports_pp,
|
|
||||||
default=False)
|
|
||||||
|
|
||||||
return any(is_pp(arch) for arch in architectures)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModelRegistry",
|
"ModelRegistry",
|
||||||
|
"HasInnerState",
|
||||||
|
"has_inner_state",
|
||||||
|
"SupportsLoRA",
|
||||||
|
"supports_lora",
|
||||||
|
"SupportsMultiModal",
|
||||||
|
"supports_multimodal",
|
||||||
|
"SupportsPP",
|
||||||
|
"supports_pp",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -25,20 +25,18 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
|||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||||
selective_scan_fn, selective_state_update)
|
selective_scan_fn, selective_state_update)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import HasInnerState
|
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
||||||
_get_graph_batch_size)
|
_get_graph_batch_size)
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA
|
from .interfaces import HasInnerState, SupportsLoRA
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|||||||
320
vllm/model_executor/models/registry.py
Normal file
320
vllm/model_executor/models/registry.py
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
import importlib
|
||||||
|
import string
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
from functools import lru_cache, partial
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import is_hip
|
||||||
|
|
||||||
|
from .interfaces import supports_multimodal, supports_pp
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_GENERATION_MODELS = {
|
||||||
|
"AquilaModel": ("llama", "LlamaForCausalLM"),
|
||||||
|
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
|
||||||
|
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
||||||
|
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
|
||||||
|
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
|
||||||
|
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
||||||
|
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
||||||
|
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
|
||||||
|
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
|
||||||
|
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
|
||||||
|
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||||
|
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
|
||||||
|
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
|
||||||
|
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
|
||||||
|
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
|
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
||||||
|
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||||
|
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||||
|
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||||
|
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
||||||
|
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
|
||||||
|
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
|
||||||
|
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
|
||||||
|
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
|
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
|
||||||
|
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
||||||
|
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
||||||
|
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
|
# For decapoda-research/llama-*
|
||||||
|
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
|
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||||
|
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||||
|
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
|
||||||
|
# transformers's mpt class has lower case
|
||||||
|
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||||
|
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||||
|
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
|
||||||
|
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
|
||||||
|
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
|
||||||
|
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
||||||
|
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
|
||||||
|
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||||
|
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||||
|
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
||||||
|
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||||
|
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||||
|
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
||||||
|
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
||||||
|
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||||
|
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
|
||||||
|
"Qwen2VLForConditionalGeneration":
|
||||||
|
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
|
||||||
|
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
|
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
|
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
|
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
||||||
|
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
||||||
|
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
||||||
|
# NOTE: The below models are for speculative decoding only
|
||||||
|
"MedusaModel": ("medusa", "Medusa"),
|
||||||
|
"EAGLEModel": ("eagle", "EAGLE"),
|
||||||
|
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_EMBEDDING_MODELS = {
|
||||||
|
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
|
||||||
|
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_MULTIMODAL_MODELS = {
|
||||||
|
"Blip2ForConditionalGeneration":
|
||||||
|
("blip2", "Blip2ForConditionalGeneration"),
|
||||||
|
"ChameleonForConditionalGeneration":
|
||||||
|
("chameleon", "ChameleonForConditionalGeneration"),
|
||||||
|
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
||||||
|
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
||||||
|
"LlavaForConditionalGeneration": ("llava",
|
||||||
|
"LlavaForConditionalGeneration"),
|
||||||
|
"LlavaNextForConditionalGeneration": ("llava_next",
|
||||||
|
"LlavaNextForConditionalGeneration"),
|
||||||
|
"LlavaNextVideoForConditionalGeneration":
|
||||||
|
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
|
||||||
|
"LlavaOnevisionForConditionalGeneration":
|
||||||
|
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||||
|
"MiniCPMV": ("minicpmv", "MiniCPMV"),
|
||||||
|
"PaliGemmaForConditionalGeneration": ("paligemma",
|
||||||
|
"PaliGemmaForConditionalGeneration"),
|
||||||
|
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||||
|
"PixtralForConditionalGeneration": ("pixtral",
|
||||||
|
"PixtralForConditionalGeneration"),
|
||||||
|
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||||
|
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
|
||||||
|
"Qwen2VLForConditionalGeneration"),
|
||||||
|
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
||||||
|
"MllamaForConditionalGeneration": ("mllama",
|
||||||
|
"MllamaForConditionalGeneration"),
|
||||||
|
}
|
||||||
|
_CONDITIONAL_GENERATION_MODELS = {
|
||||||
|
"BartModel": ("bart", "BartForConditionalGeneration"),
|
||||||
|
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_MODELS = {
|
||||||
|
**_GENERATION_MODELS,
|
||||||
|
**_EMBEDDING_MODELS,
|
||||||
|
**_MULTIMODAL_MODELS,
|
||||||
|
**_CONDITIONAL_GENERATION_MODELS,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Architecture -> type.
|
||||||
|
# out of tree models
|
||||||
|
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
|
||||||
|
|
||||||
|
# Models not supported by ROCm.
|
||||||
|
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
||||||
|
|
||||||
|
# Models partially supported by ROCm.
|
||||||
|
# Architecture -> Reason.
|
||||||
|
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
|
||||||
|
"Triton flash attention. For half-precision SWA support, "
|
||||||
|
"please use CK flash attention by setting "
|
||||||
|
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||||
|
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||||
|
"Qwen2ForCausalLM":
|
||||||
|
_ROCM_SWA_REASON,
|
||||||
|
"MistralForCausalLM":
|
||||||
|
_ROCM_SWA_REASON,
|
||||||
|
"MixtralForCausalLM":
|
||||||
|
_ROCM_SWA_REASON,
|
||||||
|
"PaliGemmaForConditionalGeneration":
|
||||||
|
("ROCm flash attention does not yet "
|
||||||
|
"fully support 32-bit precision on PaliGemma"),
|
||||||
|
"Phi3VForCausalLM":
|
||||||
|
("ROCm Triton flash attention may run into compilation errors due to "
|
||||||
|
"excessive use of shared memory. If this happens, disable Triton FA "
|
||||||
|
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRegistry:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
|
||||||
|
module_relname, cls_name = _MODELS[model_arch]
|
||||||
|
return f"vllm.model_executor.models.{module_relname}", cls_name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache(maxsize=128)
|
||||||
|
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||||
|
if model_arch not in _MODELS:
|
||||||
|
return None
|
||||||
|
|
||||||
|
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
return getattr(module, cls_name, None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||||
|
if model_arch in _OOT_MODELS:
|
||||||
|
return _OOT_MODELS[model_arch]
|
||||||
|
|
||||||
|
if is_hip():
|
||||||
|
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model architecture {model_arch} is not supported by "
|
||||||
|
"ROCm for now.")
|
||||||
|
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
||||||
|
logger.warning(
|
||||||
|
"Model architecture %s is partially supported by ROCm: %s",
|
||||||
|
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||||
|
model = ModelRegistry._try_get_model_stateless(model_arch)
|
||||||
|
if model is not None:
|
||||||
|
return model
|
||||||
|
|
||||||
|
return ModelRegistry._try_get_model_stateful(model_arch)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def resolve_model_cls(
|
||||||
|
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
|
||||||
|
if isinstance(architectures, str):
|
||||||
|
architectures = [architectures]
|
||||||
|
if not architectures:
|
||||||
|
logger.warning("No model architectures are specified")
|
||||||
|
|
||||||
|
for arch in architectures:
|
||||||
|
model_cls = ModelRegistry._try_load_model_cls(arch)
|
||||||
|
if model_cls is not None:
|
||||||
|
return (model_cls, arch)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Model architectures {architectures} are not supported for now. "
|
||||||
|
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||||
|
|
||||||
|
@staticmethod
def get_supported_archs() -> List[str]:
    """
    List every architecture name known to the registry.

    Built-in architectures (keys of ``_MODELS``) come first, followed by
    out-of-tree ones (keys of ``_OOT_MODELS``). An out-of-tree
    registration that overrides a built-in name is reported only once,
    at its built-in position, so the result never contains duplicates.
    """
    builtin_archs = list(_MODELS)
    # `register_model` allows an out-of-tree class to shadow a built-in
    # architecture; skip those names here so they are not listed twice.
    extra_archs = [arch for arch in _OOT_MODELS if arch not in _MODELS]
    return builtin_archs + extra_archs
|
||||||
|
|
||||||
|
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
    """
    Register an out-of-tree model class under an architecture name.

    Args:
        model_arch: Architecture name to register the class under.
        model_cls: The model class to associate with ``model_arch``.

    If ``model_arch`` is also a built-in architecture, a warning is
    logged and the out-of-tree class takes precedence from then on.
    """
    if model_arch in _MODELS:
        logger.warning(
            "Model architecture %s is already registered, and will be "
            "overwritten by the new model class %s.", model_arch,
            model_cls.__name__)

    _OOT_MODELS[model_arch] = model_cls

    # `_check_stateless` memoizes its answers with `lru_cache`. Anything
    # it computed for this architecture before the registration (e.g. the
    # `default` returned for a then-unknown name, or the verdict for the
    # built-in class being overridden) is now stale, so drop the cache.
    ModelRegistry._check_stateless.cache_clear()
|
||||||
|
|
||||||
|
@staticmethod
@lru_cache(maxsize=128)
def _check_stateless(
    func: Callable[[Type[nn.Module]], bool],
    model_arch: str,
    *,
    default: Optional[bool] = None,
) -> bool:
    """
    Evaluate a boolean predicate on the model class for ``model_arch``.

    When the class is available from the stateless lookup, ``func`` is
    simply called on it. When the architecture is not a built-in one and a
    ``default`` was supplied, that default is returned.

    Otherwise the model module is imported and ``func`` is evaluated in a
    separate Python process, to avoid initializing CUDA for the main
    program.

    Raises:
        ValueError: If a module, class or function name contains
            characters outside letters, digits, ``.`` and ``_``.
        RuntimeError: If the subprocess fails for a reason other than the
            predicate evaluating to false.
    """
    loaded_cls = ModelRegistry._try_get_model_stateless(model_arch)
    if loaded_cls is not None:
        return func(loaded_cls)

    if model_arch not in _MODELS and default is not None:
        return default

    module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)

    # These names are spliced into source code handed to `python -c`, so
    # restrict them to characters that can only form dotted identifiers.
    allowed_chars = frozenset(string.ascii_letters + string.digits + "._")
    if not allowed_chars.issuperset(module_name):
        raise ValueError(f"Unsafe module name detected for {model_arch}")
    if not allowed_chars.issuperset(cls_name):
        raise ValueError(f"Unsafe class name detected for {model_arch}")
    if not allowed_chars.issuperset(func.__module__):
        raise ValueError(f"Unsafe module name detected for {func}")
    if not allowed_chars.issuperset(func.__name__):
        raise ValueError(f"Unsafe class name detected for {func}")

    # A unique assertion message lets us tell "predicate was false" apart
    # from any other failure in the child process.
    err_id = uuid.uuid4()

    script = ";".join((
        f"from {module_name} import {cls_name}",
        f"from {func.__module__} import {func.__name__}",
        f"assert {func.__name__}({cls_name}), '{err_id}'",
    ))

    proc = subprocess.run([sys.executable, "-c", script],
                          capture_output=True)

    if proc.returncode == 0:
        return True

    stderr_lines = [raw.decode() for raw in proc.stderr.splitlines()]
    if stderr_lines and stderr_lines[-1] != f"AssertionError: {err_id}":
        err_str = "\n".join(stderr_lines)
        raise RuntimeError(
            "An unexpected error occurred while importing the model in "
            f"another process. Error log:\n{err_str}")

    return False
|
||||||
|
|
||||||
|
@staticmethod
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
    """
    Tell whether any of the given architectures is an embedding model.

    Args:
        architectures: A single architecture name or a list of names.

    Returns:
        ``True`` if at least one name is present in ``_EMBEDDING_MODELS``.
        An empty input logs a warning and yields ``False``.
    """
    arch_names = ([architectures]
                  if isinstance(architectures, str) else architectures)
    if not arch_names:
        logger.warning("No model architectures are specified")

    for name in arch_names:
        if name in _EMBEDDING_MODELS:
            return True
    return False
|
||||||
|
|
||||||
|
@staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
    """
    Tell whether any of the given architectures is a multimodal model.

    Each architecture is tested with ``supports_multimodal`` through
    ``_check_stateless``; architectures that are not found count as
    ``False``.

    Args:
        architectures: A single architecture name or a list of names.

    Returns:
        ``True`` if at least one architecture passes the check. An empty
        input logs a warning and yields ``False``.
    """
    arch_names = ([architectures]
                  if isinstance(architectures, str) else architectures)
    if not arch_names:
        logger.warning("No model architectures are specified")

    return any(
        ModelRegistry._check_stateless(
            supports_multimodal, name, default=False)
        for name in arch_names)
|
||||||
|
|
||||||
|
@staticmethod
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
    """
    Tell whether any of the given architectures passes ``supports_pp``.

    Each architecture is tested through ``_check_stateless``;
    architectures that are not found count as ``False``.

    Args:
        architectures: A single architecture name or a list of names.

    Returns:
        ``True`` if at least one architecture passes the check. An empty
        input logs a warning and yields ``False``.
    """
    arch_names = ([architectures]
                  if isinstance(architectures, str) else architectures)
    if not arch_names:
        logger.warning("No model architectures are specified")

    return any(
        ModelRegistry._check_stateless(supports_pp, name, default=False)
        for name in arch_names)
|
||||||
@ -35,8 +35,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
|||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
from vllm.model_executor.models.interfaces import (supports_lora,
|
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||||
supports_multimodal)
|
|
||||||
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
|
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
|
||||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||||
MultiModalInputs, MultiModalRegistry)
|
MultiModalInputs, MultiModalRegistry)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user