[Misc] Collect model support info in a single process per model (#9233)
This commit is contained in:
parent
cbc2ef5529
commit
e808156f30
@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
|
|||||||
5. Register your model
|
5. Register your model
|
||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
|
Finally, register your :code:`*ForCausalLM` class to the :code:`_VLLM_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
|
||||||
|
|
||||||
6. Out-of-Tree Model Integration
|
6. Out-of-Tree Model Integration
|
||||||
--------------------------------------------
|
--------------------------------------------
|
||||||
|
|||||||
@ -183,6 +183,8 @@ class EngineArgs:
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
self.tokenizer = self.model
|
self.tokenizer = self.model
|
||||||
|
|
||||||
|
# Setup plugins
|
||||||
from vllm.plugins import load_general_plugins
|
from vllm.plugins import load_general_plugins
|
||||||
load_general_plugins()
|
load_general_plugins()
|
||||||
|
|
||||||
|
|||||||
@ -130,6 +130,9 @@ class MQLLMEngine:
|
|||||||
def from_engine_args(cls, engine_args: AsyncEngineArgs,
|
def from_engine_args(cls, engine_args: AsyncEngineArgs,
|
||||||
usage_context: UsageContext, ipc_path: str):
|
usage_context: UsageContext, ipc_path: str):
|
||||||
"""Creates an MQLLMEngine from the engine arguments."""
|
"""Creates an MQLLMEngine from the engine arguments."""
|
||||||
|
# Setup plugins for each process
|
||||||
|
from vllm.plugins import load_general_plugins
|
||||||
|
load_general_plugins()
|
||||||
|
|
||||||
engine_config = engine_args.create_engine_config()
|
engine_config = engine_args.create_engine_config()
|
||||||
|
|
||||||
|
|||||||
@ -3,8 +3,10 @@ import pickle
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from functools import lru_cache, partial
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
from dataclasses import dataclass, field
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -116,18 +118,13 @@ _SPECULATIVE_DECODING_MODELS = {
|
|||||||
}
|
}
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
|
||||||
_MODELS = {
|
_VLLM_MODELS = {
|
||||||
**_TEXT_GENERATION_MODELS,
|
**_TEXT_GENERATION_MODELS,
|
||||||
**_EMBEDDING_MODELS,
|
**_EMBEDDING_MODELS,
|
||||||
**_MULTIMODAL_MODELS,
|
**_MULTIMODAL_MODELS,
|
||||||
**_SPECULATIVE_DECODING_MODELS,
|
**_SPECULATIVE_DECODING_MODELS,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Architecture -> type or (module, class).
|
|
||||||
# out of tree models
|
|
||||||
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
|
|
||||||
_OOT_MODELS_LAZY: Dict[str, Tuple[str, str]] = {}
|
|
||||||
|
|
||||||
# Models not supported by ROCm.
|
# Models not supported by ROCm.
|
||||||
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
||||||
|
|
||||||
@ -154,79 +151,125 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelRegistry:
|
@dataclass(frozen=True)
|
||||||
|
class _ModelInfo:
|
||||||
|
is_text_generation_model: bool
|
||||||
|
is_embedding_model: bool
|
||||||
|
supports_multimodal: bool
|
||||||
|
supports_pp: bool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
|
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||||
if model_arch in _MODELS:
|
return _ModelInfo(
|
||||||
module_relname, cls_name = _MODELS[model_arch]
|
is_text_generation_model=is_text_generation_model(model),
|
||||||
return f"vllm.model_executor.models.{module_relname}", cls_name
|
is_embedding_model=is_embedding_model(model),
|
||||||
|
supports_multimodal=supports_multimodal(model),
|
||||||
|
supports_pp=supports_pp(model),
|
||||||
|
)
|
||||||
|
|
||||||
if model_arch in _OOT_MODELS_LAZY:
|
|
||||||
return _OOT_MODELS_LAZY[model_arch]
|
|
||||||
|
|
||||||
raise KeyError(model_arch)
|
class _BaseRegisteredModel(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def inspect_model_cls(self) -> _ModelInfo:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model_cls(self) -> Type[nn.Module]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _RegisteredModel(_BaseRegisteredModel):
|
||||||
|
"""
|
||||||
|
Represents a model that has already been imported in the main process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
interfaces: _ModelInfo
|
||||||
|
model_cls: Type[nn.Module]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@lru_cache(maxsize=128)
|
def from_model_cls(model_cls: Type[nn.Module]):
|
||||||
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
|
return _RegisteredModel(
|
||||||
try:
|
interfaces=_ModelInfo.from_model_cls(model_cls),
|
||||||
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
model_cls=model_cls,
|
||||||
except KeyError:
|
)
|
||||||
return None
|
|
||||||
|
|
||||||
module = importlib.import_module(mod_name)
|
def inspect_model_cls(self) -> _ModelInfo:
|
||||||
return getattr(module, cls_name, None)
|
return self.interfaces
|
||||||
|
|
||||||
@staticmethod
|
def load_model_cls(self) -> Type[nn.Module]:
|
||||||
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
|
return self.model_cls
|
||||||
if model_arch in _OOT_MODELS:
|
|
||||||
return _OOT_MODELS[model_arch]
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||||
|
"""
|
||||||
|
Represents a model that has not been imported in the main process.
|
||||||
|
"""
|
||||||
|
module_name: str
|
||||||
|
class_name: str
|
||||||
|
|
||||||
|
# Performed in another process to avoid initializing CUDA
|
||||||
|
def inspect_model_cls(self) -> _ModelInfo:
|
||||||
|
return _run_in_subprocess(
|
||||||
|
lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
|
||||||
|
|
||||||
|
def load_model_cls(self) -> Type[nn.Module]:
|
||||||
|
mod = importlib.import_module(self.module_name)
|
||||||
|
return getattr(mod, self.class_name)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=128)
|
||||||
|
def _try_load_model_cls(
|
||||||
|
model_arch: str,
|
||||||
|
model: _BaseRegisteredModel,
|
||||||
|
) -> Optional[Type[nn.Module]]:
|
||||||
if is_hip():
|
if is_hip():
|
||||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||||
raise ValueError(
|
raise ValueError(f"Model architecture '{model_arch}' is not "
|
||||||
f"Model architecture {model_arch} is not supported by "
|
"supported by ROCm for now.")
|
||||||
"ROCm for now.")
|
|
||||||
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
|
||||||
logger.warning(
|
|
||||||
"Model architecture %s is partially supported by ROCm: %s",
|
|
||||||
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
|
|
||||||
|
|
||||||
|
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
||||||
|
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
|
||||||
|
logger.warning(
|
||||||
|
"Model architecture '%s' is partially "
|
||||||
|
"supported by ROCm: %s", model_arch, msg)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return model.load_model_cls()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error in loading model architecture '%s'",
|
||||||
|
model_arch)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
|
||||||
model = ModelRegistry._try_get_model_stateless(model_arch)
|
|
||||||
if model is not None:
|
|
||||||
return model
|
|
||||||
|
|
||||||
return ModelRegistry._try_get_model_stateful(model_arch)
|
@lru_cache(maxsize=128)
|
||||||
|
def _try_inspect_model_cls(
|
||||||
|
model_arch: str,
|
||||||
|
model: _BaseRegisteredModel,
|
||||||
|
) -> Optional[_ModelInfo]:
|
||||||
|
try:
|
||||||
|
return model.inspect_model_cls()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error in inspecting model architecture '%s'",
|
||||||
|
model_arch)
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def resolve_model_cls(
|
|
||||||
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
for arch in architectures:
|
@dataclass
|
||||||
model_cls = ModelRegistry._try_load_model_cls(arch)
|
class _ModelRegistry:
|
||||||
if model_cls is not None:
|
# Keyed by model_arch
|
||||||
return (model_cls, arch)
|
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
|
||||||
|
|
||||||
raise ValueError(
|
def get_supported_archs(self) -> List[str]:
|
||||||
f"Model architectures {architectures} are not supported for now. "
|
return list(self.models.keys())
|
||||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
|
||||||
|
|
||||||
@staticmethod
|
def register_model(
|
||||||
def get_supported_archs() -> List[str]:
|
self,
|
||||||
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
|
model_arch: str,
|
||||||
|
model_cls: Union[Type[nn.Module], str],
|
||||||
@staticmethod
|
) -> None:
|
||||||
def register_model(model_arch: str, model_cls: Union[Type[nn.Module],
|
|
||||||
str]):
|
|
||||||
"""
|
"""
|
||||||
Register an external model to be used in vLLM.
|
Register an external model to be used in vLLM.
|
||||||
|
|
||||||
@ -238,7 +281,7 @@ class ModelRegistry:
|
|||||||
when importing the model and thus the related error
|
when importing the model and thus the related error
|
||||||
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
|
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
|
||||||
"""
|
"""
|
||||||
if model_arch in _MODELS:
|
if model_arch in self.models:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Model architecture %s is already registered, and will be "
|
"Model architecture %s is already registered, and will be "
|
||||||
"overwritten by the new model class %s.", model_arch,
|
"overwritten by the new model class %s.", model_arch,
|
||||||
@ -250,43 +293,110 @@ class ModelRegistry:
|
|||||||
msg = "Expected a string in the format `<module>:<class>`"
|
msg = "Expected a string in the format `<module>:<class>`"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
module_name, cls_name = split_str
|
model = _LazyRegisteredModel(*split_str)
|
||||||
_OOT_MODELS_LAZY[model_arch] = module_name, cls_name
|
|
||||||
else:
|
else:
|
||||||
_OOT_MODELS[model_arch] = model_cls
|
model = _RegisteredModel.from_model_cls(model_cls)
|
||||||
|
|
||||||
@staticmethod
|
self.models[model_arch] = model
|
||||||
@lru_cache(maxsize=128)
|
|
||||||
def _check_stateless(
|
def _raise_for_unsupported(self, architectures: List[str]):
|
||||||
func: Callable[[Type[nn.Module]], bool],
|
all_supported_archs = self.get_supported_archs()
|
||||||
model_arch: str,
|
|
||||||
*,
|
raise ValueError(
|
||||||
default: Optional[bool] = None,
|
f"Model architectures {architectures} are not supported for now. "
|
||||||
|
f"Supported architectures: {all_supported_archs}")
|
||||||
|
|
||||||
|
def _try_load_model_cls(self,
|
||||||
|
model_arch: str) -> Optional[Type[nn.Module]]:
|
||||||
|
if model_arch not in self.models:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _try_load_model_cls(model_arch, self.models[model_arch])
|
||||||
|
|
||||||
|
def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
|
||||||
|
if model_arch not in self.models:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _try_inspect_model_cls(model_arch, self.models[model_arch])
|
||||||
|
|
||||||
|
def _normalize_archs(
|
||||||
|
self,
|
||||||
|
architectures: Union[str, List[str]],
|
||||||
|
) -> List[str]:
|
||||||
|
if isinstance(architectures, str):
|
||||||
|
architectures = [architectures]
|
||||||
|
if not architectures:
|
||||||
|
logger.warning("No model architectures are specified")
|
||||||
|
|
||||||
|
return architectures
|
||||||
|
|
||||||
|
def inspect_model_cls(
|
||||||
|
self,
|
||||||
|
architectures: Union[str, List[str]],
|
||||||
|
) -> _ModelInfo:
|
||||||
|
architectures = self._normalize_archs(architectures)
|
||||||
|
|
||||||
|
for arch in architectures:
|
||||||
|
model_info = self._try_inspect_model_cls(arch)
|
||||||
|
if model_info is not None:
|
||||||
|
return model_info
|
||||||
|
|
||||||
|
return self._raise_for_unsupported(architectures)
|
||||||
|
|
||||||
|
def resolve_model_cls(
|
||||||
|
self,
|
||||||
|
architectures: Union[str, List[str]],
|
||||||
|
) -> Tuple[Type[nn.Module], str]:
|
||||||
|
architectures = self._normalize_archs(architectures)
|
||||||
|
|
||||||
|
for arch in architectures:
|
||||||
|
model_cls = self._try_load_model_cls(arch)
|
||||||
|
if model_cls is not None:
|
||||||
|
return (model_cls, arch)
|
||||||
|
|
||||||
|
return self._raise_for_unsupported(architectures)
|
||||||
|
|
||||||
|
def is_text_generation_model(
|
||||||
|
self,
|
||||||
|
architectures: Union[str, List[str]],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
return self.inspect_model_cls(architectures).is_text_generation_model
|
||||||
Run a boolean function against a model and return the result.
|
|
||||||
|
|
||||||
If the model is not found, returns the provided default value.
|
def is_embedding_model(
|
||||||
|
self,
|
||||||
|
architectures: Union[str, List[str]],
|
||||||
|
) -> bool:
|
||||||
|
return self.inspect_model_cls(architectures).is_embedding_model
|
||||||
|
|
||||||
If the model is not already imported, the function is run inside a
|
def is_multimodal_model(
|
||||||
subprocess to avoid initializing CUDA for the main program.
|
self,
|
||||||
"""
|
architectures: Union[str, List[str]],
|
||||||
model = ModelRegistry._try_get_model_stateless(model_arch)
|
) -> bool:
|
||||||
if model is not None:
|
return self.inspect_model_cls(architectures).supports_multimodal
|
||||||
return func(model)
|
|
||||||
|
|
||||||
try:
|
def is_pp_supported_model(
|
||||||
mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
|
self,
|
||||||
except KeyError:
|
architectures: Union[str, List[str]],
|
||||||
if default is not None:
|
) -> bool:
|
||||||
return default
|
return self.inspect_model_cls(architectures).supports_pp
|
||||||
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
ModelRegistry = _ModelRegistry({
|
||||||
|
model_arch: _LazyRegisteredModel(
|
||||||
|
module_name=f"vllm.model_executor.models.{mod_relname}",
|
||||||
|
class_name=cls_name,
|
||||||
|
)
|
||||||
|
for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
|
||||||
|
})
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
|
||||||
with tempfile.NamedTemporaryFile() as output_file:
|
with tempfile.NamedTemporaryFile() as output_file:
|
||||||
# `cloudpickle` allows pickling lambda functions directly
|
# `cloudpickle` allows pickling lambda functions directly
|
||||||
input_bytes = cloudpickle.dumps(
|
input_bytes = cloudpickle.dumps((fn, output_file.name))
|
||||||
(mod_name, cls_name, func, output_file.name))
|
|
||||||
# cannot use `sys.executable __file__` here because the script
|
# cannot use `sys.executable __file__` here because the script
|
||||||
# contains relative imports
|
# contains relative imports
|
||||||
returned = subprocess.run(
|
returned = subprocess.run(
|
||||||
@ -299,71 +409,25 @@ class ModelRegistry:
|
|||||||
returned.check_returncode()
|
returned.check_returncode()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# wrap raised exception to provide more information
|
# wrap raised exception to provide more information
|
||||||
raise RuntimeError(f"Error happened when testing "
|
raise RuntimeError(f"Error raised in subprocess:\n"
|
||||||
f"model support for{mod_name}.{cls_name}:\n"
|
|
||||||
f"{returned.stderr.decode()}") from e
|
f"{returned.stderr.decode()}") from e
|
||||||
|
|
||||||
with open(output_file.name, "rb") as f:
|
with open(output_file.name, "rb") as f:
|
||||||
result = pickle.load(f)
|
return pickle.load(f)
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
is_txt_gen = partial(ModelRegistry._check_stateless,
|
def _run() -> None:
|
||||||
is_text_generation_model,
|
# Setup plugins
|
||||||
default=False)
|
from vllm.plugins import load_general_plugins
|
||||||
|
load_general_plugins()
|
||||||
|
|
||||||
return any(is_txt_gen(arch) for arch in architectures)
|
fn, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||||
|
|
||||||
@staticmethod
|
result = fn()
|
||||||
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
is_emb = partial(ModelRegistry._check_stateless,
|
with open(output_file, "wb") as f:
|
||||||
is_embedding_model,
|
f.write(pickle.dumps(result))
|
||||||
default=False)
|
|
||||||
|
|
||||||
return any(is_emb(arch) for arch in architectures)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
is_mm = partial(ModelRegistry._check_stateless,
|
|
||||||
supports_multimodal,
|
|
||||||
default=False)
|
|
||||||
|
|
||||||
return any(is_mm(arch) for arch in architectures)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
|
|
||||||
if isinstance(architectures, str):
|
|
||||||
architectures = [architectures]
|
|
||||||
if not architectures:
|
|
||||||
logger.warning("No model architectures are specified")
|
|
||||||
|
|
||||||
is_pp = partial(ModelRegistry._check_stateless,
|
|
||||||
supports_pp,
|
|
||||||
default=False)
|
|
||||||
|
|
||||||
return any(is_pp(arch) for arch in architectures)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
(mod_name, cls_name, func,
|
_run()
|
||||||
output_file) = pickle.loads(sys.stdin.buffer.read())
|
|
||||||
mod = importlib.import_module(mod_name)
|
|
||||||
klass = getattr(mod, cls_name)
|
|
||||||
result = func(klass)
|
|
||||||
with open(output_file, "wb") as f:
|
|
||||||
f.write(pickle.dumps(result))
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user