From 518369d78c1ec9ffef308131366e4bda745b5573 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 12 Dec 2023 22:21:45 -0800
Subject: [PATCH] Implement lazy model loader (#2044)

---
 vllm/model_executor/model_loader.py    |  62 ++-----------
 vllm/model_executor/models/__init__.py | 115 +++++++++++++++++--------
 vllm/model_executor/models/mixtral.py  |  13 +--
 3 files changed, 89 insertions(+), 101 deletions(-)

diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index e7bd7548..37543d8c 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -7,54 +7,9 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig
-from vllm.model_executor.models import *
+from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
-from vllm.utils import is_hip
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-# TODO(woosuk): Lazy-load the model classes.
-_MODEL_REGISTRY = {
-    "AquilaModel": AquilaForCausalLM,
-    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
-    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
-    "BloomForCausalLM": BloomForCausalLM,
-    "ChatGLMModel": ChatGLMForCausalLM,
-    "ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
-    "FalconForCausalLM": FalconForCausalLM,
-    "GPT2LMHeadModel": GPT2LMHeadModel,
-    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
-    "GPTJForCausalLM": GPTJForCausalLM,
-    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
-    "InternLMForCausalLM": InternLMForCausalLM,
-    "LlamaForCausalLM": LlamaForCausalLM,
-    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
-    "MistralForCausalLM": MistralForCausalLM,
-    "MixtralForCausalLM": MixtralForCausalLM,
-    # transformers's mpt class has lower case
-    "MptForCausalLM": MPTForCausalLM,
-    "MPTForCausalLM": MPTForCausalLM,
-    "OPTForCausalLM": OPTForCausalLM,
-    "PhiForCausalLM": PhiForCausalLM,
-    "QWenLMHeadModel": QWenLMHeadModel,
-    "RWForCausalLM": FalconForCausalLM,
-    "YiForCausalLM": YiForCausalLM,
-}
-
-# Models to be disabled in ROCm
-_ROCM_UNSUPPORTED_MODELS = []
-if is_hip():
-    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
-        del _MODEL_REGISTRY[rocm_model]
-
-# Models partially supported in ROCm
-_ROCM_PARTIALLY_SUPPORTED_MODELS = {
-    "MistralForCausalLM":
-    "Sliding window attention is not supported in ROCm's flash attention",
-}
 
 
 @contextlib.contextmanager
@@ -69,19 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
 def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
-        if arch in _MODEL_REGISTRY:
-            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
-                logger.warning(
-                    f"{arch} is not fully supported in ROCm. Reason: "
-                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
-            return _MODEL_REGISTRY[arch]
-        elif arch in _ROCM_UNSUPPORTED_MODELS:
-            raise ValueError(
-                f"Model architecture {arch} is not supported by ROCm for now. \n"
-                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return model_cls
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
" - f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") + f"Supported architectures: {ModelRegistry.get_supported_archs()}") def get_model(model_config: ModelConfig) -> nn.Module: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 28a0aa77..5596884f 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,41 +1,80 @@ -from vllm.model_executor.models.aquila import AquilaForCausalLM -from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM, - BaichuanForCausalLM) -from vllm.model_executor.models.bloom import BloomForCausalLM -from vllm.model_executor.models.falcon import FalconForCausalLM -from vllm.model_executor.models.gpt2 import GPT2LMHeadModel -from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM -from vllm.model_executor.models.gpt_j import GPTJForCausalLM -from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM -from vllm.model_executor.models.internlm import InternLMForCausalLM -from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.models.mistral import MistralForCausalLM -from vllm.model_executor.models.mixtral import MixtralForCausalLM -from vllm.model_executor.models.mpt import MPTForCausalLM -from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.model_executor.models.phi_1_5 import PhiForCausalLM -from vllm.model_executor.models.qwen import QWenLMHeadModel -from vllm.model_executor.models.chatglm import ChatGLMForCausalLM -from vllm.model_executor.models.yi import YiForCausalLM +import importlib +from typing import List, Optional, Type + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import is_hip + +logger = init_logger(__name__) + +# Architecture -> (module, class). +_MODELS = { + "AquilaModel": ("aquila", "AquilaForCausalLM"), + "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2 + "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b + "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b + "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), + "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), + "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), + "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MistralForCausalLM": ("mistral", "MistralForCausalLM"), + "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + # transformers's mpt class has lower case + "MptForCausalLM": ("mpt", "MPTForCausalLM"), + "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "OPTForCausalLM": ("opt", "OPTForCausalLM"), + "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "YiForCausalLM": ("yi", "YiForCausalLM"), +} + +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"] + +# Models partially supported by ROCm. +# Architecture -> Reason. 
+_ROCM_PARTIALLY_SUPPORTED_MODELS = { + "MistralForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", +} + + +class ModelRegistry: + + @staticmethod + def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in _MODELS: + return None + if is_hip(): + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError( + f"Model architecture {model_arch} is not supported by " + "ROCm for now.") + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + logger.warning( + f"Model architecture {model_arch} is partially supported " + "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + + module_name, model_cls_name = _MODELS[model_arch] + module = importlib.import_module( + f"vllm.model_executor.models.{module_name}") + return getattr(module, model_cls_name, None) + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + __all__ = [ - "AquilaForCausalLM", - "BaiChuanForCausalLM", - "BaichuanForCausalLM", - "BloomForCausalLM", - "ChatGLMForCausalLM", - "FalconForCausalLM", - "GPT2LMHeadModel", - "GPTBigCodeForCausalLM", - "GPTJForCausalLM", - "GPTNeoXForCausalLM", - "InternLMForCausalLM", - "LlamaForCausalLM", - "MPTForCausalLM", - "OPTForCausalLM", - "PhiForCausalLM", - "QWenLMHeadModel", - "MistralForCausalLM", - "MixtralForCausalLM", - "YiForCausalLM", + "ModelRegistry", ] diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 3021ced8..8e0a094c 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -33,14 +33,15 @@ from transformers import MixtralConfig try: import megablocks.ops as ops -except ImportError: - print( - "MegaBlocks not found. Please install it by `pip install megablocks`.") +except ImportError as e: + raise ImportError("MegaBlocks not found. " + "Please install it by `pip install megablocks`.") from e try: import stk -except ImportError: - print( - "STK not found: please see https://github.com/stanford-futuredata/stk") +except ImportError as e: + raise ImportError( + "STK not found. " + "Please install it by `pip install stanford-stk`.") from e from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention import PagedAttention
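
Usage sketch (reviewer's note, not part of the patch): the snippet below shows how the ModelRegistry added above could be exercised, assuming a vLLM checkout that includes this change. All class and method names come from the diff itself; the architecture strings are keys of the new _MODELS table.

# Illustrative only; assumes vLLM with this patch installed.
from vllm.model_executor.models import ModelRegistry

# Lazily imports vllm/model_executor/models/llama.py and returns the class;
# no other model module is imported as a side effect.
llama_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")
print(llama_cls)

# Unknown architectures resolve to None instead of raising, which is what
# _get_model_architecture relies on before reporting the supported list.
print(ModelRegistry.load_model_cls("NotARealArchitecture"))  # None
print(ModelRegistry.get_supported_archs())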