[Misc] Collect model support info in a single process per model (#9233)

This commit is contained in:
Cyrus Leung 2024-10-11 19:08:11 +08:00 committed by GitHub
parent cbc2ef5529
commit e808156f30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 236 additions and 167 deletions

View File

@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
5. Register your model
----------------------
Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
Finally, register your :code:`*ForCausalLM` class to the :code:`_VLLM_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
6. Out-of-Tree Model Integration
--------------------------------------------

View File

@ -183,6 +183,8 @@ class EngineArgs:
def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()

View File

@ -130,6 +130,9 @@ class MQLLMEngine:
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments."""
# Setup plugins for each process
from vllm.plugins import load_general_plugins
load_general_plugins()
engine_config = engine_args.create_engine_config()

View File

@ -3,8 +3,10 @@ import pickle
import subprocess
import sys
import tempfile
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
import cloudpickle
import torch.nn as nn
@ -116,18 +118,13 @@ _SPECULATIVE_DECODING_MODELS = {
}
# yapf: enable
_MODELS = {
_VLLM_MODELS = {
**_TEXT_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS,
**_SPECULATIVE_DECODING_MODELS,
}
# Architecture -> type or (module, class).
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
_OOT_MODELS_LAZY: Dict[str, Tuple[str, str]] = {}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
@ -154,79 +151,125 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
}
class ModelRegistry:
@dataclass(frozen=True)
class _ModelInfo:
is_text_generation_model: bool
is_embedding_model: bool
supports_multimodal: bool
supports_pp: bool
@staticmethod
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
if model_arch in _MODELS:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
return _ModelInfo(
is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
)
if model_arch in _OOT_MODELS_LAZY:
return _OOT_MODELS_LAZY[model_arch]
raise KeyError(model_arch)
class _BaseRegisteredModel(ABC):
@abstractmethod
def inspect_model_cls(self) -> _ModelInfo:
raise NotImplementedError
@abstractmethod
def load_model_cls(self) -> Type[nn.Module]:
raise NotImplementedError
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
"""
Represents a model that has already been imported in the main process.
"""
interfaces: _ModelInfo
model_cls: Type[nn.Module]
@staticmethod
@lru_cache(maxsize=128)
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
try:
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
except KeyError:
return None
def from_model_cls(model_cls: Type[nn.Module]):
return _RegisteredModel(
interfaces=_ModelInfo.from_model_cls(model_cls),
model_cls=model_cls,
)
module = importlib.import_module(mod_name)
return getattr(module, cls_name, None)
def inspect_model_cls(self) -> _ModelInfo:
return self.interfaces
@staticmethod
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
def load_model_cls(self) -> Type[nn.Module]:
return self.model_cls
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
"""
Represents a model that has not been imported in the main process.
"""
module_name: str
class_name: str
# Performed in another process to avoid initializing CUDA
def inspect_model_cls(self) -> _ModelInfo:
return _run_in_subprocess(
lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
def load_model_cls(self) -> Type[nn.Module]:
mod = importlib.import_module(self.module_name)
return getattr(mod, self.class_name)
@lru_cache(maxsize=128)
def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(f"Model architecture '{model_arch}' is not "
"supported by ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
logger.warning(
"Model architecture '%s' is partially "
"supported by ROCm: %s", model_arch, msg)
try:
return model.load_model_cls()
except Exception:
logger.exception("Error in loading model architecture '%s'",
model_arch)
return None
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return model
return ModelRegistry._try_get_model_stateful(model_arch)
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
try:
return model.inspect_model_cls()
except Exception:
logger.exception("Error in inspecting model architecture '%s'",
model_arch)
return None
@staticmethod
def resolve_model_cls(
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
@dataclass
class _ModelRegistry:
# Keyed by model_arch
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_supported_archs(self) -> List[str]:
return list(self.models.keys())
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Union[Type[nn.Module],
str]):
def register_model(
self,
model_arch: str,
model_cls: Union[Type[nn.Module], str],
) -> None:
"""
Register an external model to be used in vLLM.
@ -238,7 +281,7 @@ class ModelRegistry:
when importing the model and thus the related error
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if model_arch in _MODELS:
if model_arch in self.models:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
@ -250,120 +293,141 @@ class ModelRegistry:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
module_name, cls_name = split_str
_OOT_MODELS_LAZY[model_arch] = module_name, cls_name
model = _LazyRegisteredModel(*split_str)
else:
_OOT_MODELS[model_arch] = model_cls
model = _RegisteredModel.from_model_cls(model_cls)
@staticmethod
@lru_cache(maxsize=128)
def _check_stateless(
func: Callable[[Type[nn.Module]], bool],
model_arch: str,
*,
default: Optional[bool] = None,
self.models[model_arch] = model
def _raise_for_unsupported(self, architectures: List[str]):
all_supported_archs = self.get_supported_archs()
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {all_supported_archs}")
def _try_load_model_cls(self,
model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in self.models:
return None
return _try_load_model_cls(model_arch, self.models[model_arch])
def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
if model_arch not in self.models:
return None
return _try_inspect_model_cls(model_arch, self.models[model_arch])
def _normalize_archs(
self,
architectures: Union[str, List[str]],
) -> List[str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
return architectures
def inspect_model_cls(
self,
architectures: Union[str, List[str]],
) -> _ModelInfo:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return model_info
return self._raise_for_unsupported(architectures)
def resolve_model_cls(
self,
architectures: Union[str, List[str]],
) -> Tuple[Type[nn.Module], str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
return self._raise_for_unsupported(architectures)
def is_text_generation_model(
self,
architectures: Union[str, List[str]],
) -> bool:
"""
Run a boolean function against a model and return the result.
return self.inspect_model_cls(architectures).is_text_generation_model
If the model is not found, returns the provided default value.
def is_embedding_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).is_embedding_model
If the model is not already imported, the function is run inside a
subprocess to avoid initializing CUDA for the main program.
"""
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return func(model)
def is_multimodal_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_multimodal
def is_pp_supported_model(
self,
architectures: Union[str, List[str]],
) -> bool:
return self.inspect_model_cls(architectures).supports_pp
ModelRegistry = _ModelRegistry({
model_arch: _LazyRegisteredModel(
module_name=f"vllm.model_executor.models.{mod_relname}",
class_name=cls_name,
)
for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
})
_T = TypeVar("_T")
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
with tempfile.NamedTemporaryFile() as output_file:
# `cloudpickle` allows pickling lambda functions directly
input_bytes = cloudpickle.dumps((fn, output_file.name))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
returned = subprocess.run(
[sys.executable, "-m", "vllm.model_executor.models.registry"],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
except KeyError:
if default is not None:
return default
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(f"Error raised in subprocess:\n"
f"{returned.stderr.decode()}") from e
raise
with open(output_file.name, "rb") as f:
return pickle.load(f)
with tempfile.NamedTemporaryFile() as output_file:
# `cloudpickle` allows pickling lambda functions directly
input_bytes = cloudpickle.dumps(
(mod_name, cls_name, func, output_file.name))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
returned = subprocess.run(
[sys.executable, "-m", "vllm.model_executor.models.registry"],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(f"Error happened when testing "
f"model support for{mod_name}.{cls_name}:\n"
f"{returned.stderr.decode()}") from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
return result
def _run() -> None:
# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
@staticmethod
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
fn, output_file = pickle.loads(sys.stdin.buffer.read())
is_txt_gen = partial(ModelRegistry._check_stateless,
is_text_generation_model,
default=False)
result = fn()
return any(is_txt_gen(arch) for arch in architectures)
@staticmethod
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_emb = partial(ModelRegistry._check_stateless,
is_embedding_model,
default=False)
return any(is_emb(arch) for arch in architectures)
@staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_mm = partial(ModelRegistry._check_stateless,
supports_multimodal,
default=False)
return any(is_mm(arch) for arch in architectures)
@staticmethod
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_pp = partial(ModelRegistry._check_stateless,
supports_pp,
default=False)
return any(is_pp(arch) for arch in architectures)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
if __name__ == "__main__":
(mod_name, cls_name, func,
output_file) = pickle.loads(sys.stdin.buffer.read())
mod = importlib.import_module(mod_name)
klass = getattr(mod, cls_name)
result = func(klass)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
_run()