[Model] Add base class for LoRA-supported models (#5018)
parent d12af207d2
commit 96354d6a29
@@ -4,6 +4,9 @@ Using LoRA adapters
 ===================

 This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.

+LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`.
+
 Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save
 them locally with
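The documentation text above stops at the download step. A minimal sketch of that step, assuming the adapter is hosted on the Hugging Face Hub; the repository id below is a placeholder, not one referenced by this commit:

    from huggingface_hub import snapshot_download

    # Hypothetical adapter repository; substitute the LoRA adapter you want to serve.
    lora_path = snapshot_download(repo_id="your-org/your-lora-adapter")

The returned local path can then be handed to vLLM as the adapter location.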
@@ -2,6 +2,7 @@ from typing import List, Optional
 from typing import Sequence as GenericSequence

 import torch
+import torch.types

 from vllm.utils import is_pin_memory_available

@@ -64,7 +65,7 @@ class LoRALayerWeights:
            output_dim: int,
            rank: int,
            dtype: torch.dtype,
-           device: torch.device,
+           device: torch.types.Device,
            embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        lora_a = torch.zeros([input_dim, rank],
@@ -18,6 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              parse_fine_tuned_lora_name, replace_submodule)
+from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.utils import LRUCache, is_pin_memory_available

 logger = init_logger(__name__)
@@ -363,7 +364,7 @@ class LoRAModelManager:

     def __init__(
         self,
-        model: nn.Module,
+        model: SupportsLoRA,
         max_num_seqs: int,
         max_num_batched_tokens: int,
         vocab_size: int,
@@ -411,7 +412,7 @@ class LoRAModelManager:
         # embeddings_indices
         self.indices_len: List[Optional[int]] = [None] * 4

-        self.model: nn.Module = model
+        self.model = model
         if hasattr(self.model, "supported_lora_modules"):
             self.supported_lora_modules = copy.deepcopy(
                 self.model.supported_lora_modules)
@@ -428,7 +429,6 @@ class LoRAModelManager:
         self._active_loras: Dict[int, None] = {}
         self._last_mapping: Optional[LoRAMapping] = None
         self._create_lora_modules()
-        self.model.lora_manager = self

     @property
     def capacity(self) -> int:
@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
+from vllm.model_executor.models.interfaces import (supports_lora,
+                                                   supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import is_tpu

@@ -64,12 +65,15 @@ def _get_quantization_config(


 def _get_model_initialization_kwargs(
-        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig]
+        model_class: Type[nn.Module],
+        lora_config: Optional[LoRAConfig],
+        vlm_config: Optional[VisionLanguageConfig],
 ) -> Dict[str, Any]:
     """Get extra kwargs for model initialization."""
     extra_kwargs: Dict[str, Any] = {}
-    if hasattr(model_class, "supported_lora_modules"):
+    if supports_lora(model_class):
+        # lora_config=None is used to disable LoRA
         extra_kwargs["lora_config"] = lora_config
     elif lora_config:
         raise ValueError(
@@ -77,13 +81,15 @@ def _get_model_initialization_kwargs(
             "but LoRA is enabled. Support for this model may "
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
-    elif issubclass(model_class, VisionLanguageModelBase):
-        if vision_language_config is None:
+    if supports_vision(model_class):
+        if vlm_config is None:
             raise ValueError("Provide `image_input_type` and other vision "
                              "related configurations through LLM entrypoint "
                              "or engine arguments.")

-        extra_kwargs["vision_language_config"] = vision_language_config
+        extra_kwargs["vlm_config"] = vlm_config

     return extra_kwargs

@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@@ -292,7 +294,9 @@ class BaiChuanModel(nn.Module):
         return hidden_states


-class BaiChuanBaseForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "W_pack": ["W_pack"],
         "gate_up_proj": [
@@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module):

     def __init__(
         self,
-        config,
+        config: PretrainedConfig,
         position_embedding: str,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, cache_config,
                                    quant_config)
@@ -28,6 +28,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig

+from .interfaces import SupportsLoRA
+

 class GLMAttention(nn.Module):

@@ -322,7 +324,9 @@ class ChatGLMModel(nn.Module):
         return hidden_states


-class ChatGLMForCausalLM(nn.Module):
+class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"]
@@ -345,7 +349,10 @@ class ChatGLMForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
-        self.config: ChatGLMConfig = config
+
+        self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length",
                                                8192)
@@ -26,7 +26,7 @@
 from typing import Iterable, Optional, Tuple

 import torch
-from transformers import PretrainedConfig
+from transformers import LlamaConfig

 from vllm.config import CacheConfig, LoRAConfig
 from vllm.model_executor.layers.quantization.base_config import (
@@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):

     def __init__(
         self,
-        config: Optional[PretrainedConfig] = None,
+        config: LlamaConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+
 logger = init_logger(__name__)


@@ -288,7 +290,9 @@ class GemmaModel(nn.Module):
         return hidden_states


-class GemmaForCausalLM(nn.Module):
+class GemmaForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -319,9 +323,11 @@ class GemmaForCausalLM(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = GemmaModel(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 class GPTBigCodeAttention(nn.Module):

@@ -230,7 +232,9 @@ class GPTBigCodeModel(nn.Module):
         return hidden_states


-class GPTBigCodeForCausalLM(nn.Module):
+class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {"c_attn": ["c_attn"]}

     supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
@@ -250,7 +254,10 @@ class GPTBigCodeForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                            lora_config)
vllm/model_executor/models/interfaces.py (new file, +130 lines)
@@ -0,0 +1,130 @@
+from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
+                    Union, overload, runtime_checkable)
+
+from typing_extensions import TypeGuard
+
+from vllm.config import LoRAConfig, VisionLanguageConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@runtime_checkable
+class SupportsVision(Protocol):
+    """The interface required for all vision language models (VLMs)."""
+
+    supports_vision: ClassVar[Literal[True]]
+
+    def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsVisionType(Protocol):
+    supports_vision: Literal[True]
+
+    def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
+        ...
+
+
+@overload
+def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
+    ...
+
+
+@overload
+def supports_vision(model: object) -> TypeGuard[SupportsVision]:
+    ...
+
+
+def supports_vision(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsVisionType)
+
+    return isinstance(model, SupportsVision)
+
+
+@runtime_checkable
+class SupportsLoRA(Protocol):
+    """The interface required for all models that support LoRA."""
+
+    supports_lora: ClassVar[Literal[True]]
+
+    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
+    supported_lora_modules: ClassVar[List[str]]
+    embedding_modules: ClassVar[Dict[str, str]]
+    embedding_padding_modules: ClassVar[List[str]]
+
+    # lora_config is None when LoRA is not enabled
+    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsLoRAType(Protocol):
+    supports_lora: Literal[True]
+
+    packed_modules_mapping: Dict[str, List[str]]
+    supported_lora_modules: List[str]
+    embedding_modules: Dict[str, str]
+    embedding_padding_modules: List[str]
+
+    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+        ...
+
+
+@overload
+def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
+    ...
+
+
+@overload
+def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
+    ...
+
+
+def supports_lora(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+    result = _supports_lora(model)
+
+    if not result:
+        lora_attrs = (
+            "packed_modules_mapping",
+            "supported_lora_modules",
+            "embedding_modules",
+            "embedding_padding_modules",
+        )
+        missing_attrs = tuple(attr for attr in lora_attrs
+                              if not hasattr(model, attr))
+
+        if getattr(model, "supports_lora", False):
+            if missing_attrs:
+                logger.warning(
+                    "The model (%s) sets `supports_lora=True`, "
+                    "but is missing LoRA-specific attributes: %s",
+                    model,
+                    missing_attrs,
+                )
+        else:
+            if not missing_attrs:
+                logger.warning(
+                    "The model (%s) contains all LoRA-specific attributes, "
+                    "but does not set `supports_lora=True`.", model)
+
+    return result
+
+
+def _supports_lora(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsLoRAType)
+
+    return isinstance(model, SupportsLoRA)
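A short usage sketch of the helpers introduced above, mirroring how this commit calls them: the model loader checks the class (`supports_lora(model_class)`), while the model runner checks the instance (`supports_lora(self.model)`). `MyLoRAModel` is a hypothetical stand-in, not a model touched by this commit:

    from typing import Optional

    from torch import nn

    from vllm.config import LoRAConfig
    from vllm.model_executor.models.interfaces import SupportsLoRA, supports_lora


    class MyLoRAModel(nn.Module, SupportsLoRA):
        # Hypothetical model: declaring the flag plus the four LoRA attributes
        # is what makes supports_lora() return True and avoids the warning
        # path for partially implemented models.
        supports_lora = True

        packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
        supported_lora_modules = ["qkv_proj"]
        embedding_modules = {}
        embedding_padding_modules = []

        def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
            super().__init__()
            self.lora_config = lora_config


    # Works on the class, as in _get_model_initialization_kwargs() ...
    assert supports_lora(MyLoRAModel)
    # ... and on an instance, as in the GPU model runner's LoRA check.
    assert supports_lora(MyLoRAModel(lora_config=None))

Because `runtime_checkable` protocols with data members cannot be used with issubclass, `supports_lora` routes class objects through the `_SupportsLoRAType` isinstance check, which is the workaround described in the comments above.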
@@ -49,6 +49,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.utils import is_hip, print_warning_once

+from .interfaces import SupportsLoRA
+

 class LlamaMLP(nn.Module):

@@ -296,7 +298,9 @@ class LlamaModel(nn.Module):
         return hidden_states


-class LlamaForCausalLM(nn.Module):
+class LlamaForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -336,7 +340,10 @@ class LlamaForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.model = LlamaModel(config,
                                 cache_config,
                                 quant_config,
@@ -20,7 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import get_dummy_image_data
 from vllm.sequence import SamplerOutput

-from .vlm_base import VisionLanguageModelBase
+from .interfaces import SupportsVision

 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -86,18 +86,21 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
 @MULTIMODAL_REGISTRY.register_image_feature_input()
 @MULTIMODAL_REGISTRY.register_image_pixel_input()
 @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
-class LlavaForConditionalGeneration(VisionLanguageModelBase):
+class LlavaForConditionalGeneration(nn.Module, SupportsVision):

+    supports_vision = True
+
     def __init__(self,
                  config: LlavaConfig,
-                 vision_language_config: VisionLanguageConfig,
+                 vlm_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
-        super().__init__(vision_language_config)
+        super().__init__()

         self.config = config
+        self.vlm_config = vlm_config

-        if self.vision_language_config.image_input_type == (
+        if self.vlm_config.image_input_type == (
             VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config.vision_config)
         else:
@@ -122,11 +125,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
         self.sampler = Sampler()

     def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != list(
-                self.vision_language_config.image_input_shape[1:]):
+        if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
             raise ValueError(
                 f"The expected image tensor shape is batch dimension plus "
-                f"{self.vision_language_config.image_input_shape[1:]}. "
+                f"{self.vlm_config.image_input_shape[1:]}. "
                 f"You supplied {data.shape}. "
                 f"If you are using vLLM's entrypoint, make sure your "
                 f"supplied image input is consistent with "
@@ -139,7 +141,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
         pixel_values = kwargs.pop("pixel_values", None)
         image_features = kwargs.pop("image_features", None)

-        expected_input_type = self.vision_language_config.image_input_type
+        expected_input_type = self.vlm_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType

         if expected_input_type == ImageInputType.PIXEL_VALUES:
@@ -273,7 +275,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):

             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vision_language_config.image_token_id)
+                self.vlm_config.image_token_id)

             input_ids = None
         else:
@@ -25,8 +25,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
 from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector, merge_vision_embeddings
-from .vlm_base import VisionLanguageModelBase

 logger = init_logger(__name__)

@@ -106,19 +106,21 @@ def _image_pixel_processor(

 @MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
 @MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
-class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
+class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

+    supports_vision = True
+
     def __init__(self,
                  config: LlavaNextConfig,
-                 vision_language_config: VisionLanguageConfig,
+                 vlm_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
-        super().__init__(vision_language_config)
+        super().__init__()

-        # Update the type annotation from that of its superclass
         self.config = config
+        self.vlm_config = vlm_config

-        if self.vision_language_config.image_input_type == (
+        if self.vlm_config.image_input_type == (
             VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config=config.vision_config)
         else:
@@ -146,7 +148,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
             torch.empty(config.text_config.hidden_size))

     def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
-        _, num_channels, _, _ = self.vision_language_config.image_input_shape
+        _, num_channels, _, _ = self.vlm_config.image_input_shape

         # Note that this is different from that of vLLM vision_language_config
         # since the image is resized by the HuggingFace preprocessor
@@ -177,7 +179,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
         image_sizes = kwargs.pop("image_sizes", None)
         image_features = kwargs.pop("image_features", None)

-        expected_input_type = self.vision_language_config.image_input_type
+        expected_input_type = self.vlm_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType

         if expected_input_type == ImageInputType.PIXEL_VALUES:
@@ -386,7 +388,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):

             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vision_language_config.image_token_id)
+                self.vlm_config.image_token_id)

             input_ids = None
         else:
@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple

 import torch
 from torch import nn
+from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
@@ -51,6 +52,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 class MiniCPMMoE(nn.Module):
     """A tensor-parallel MoE implementation that shards each expert
@@ -388,7 +391,9 @@ class MiniCPMModel(nn.Module):
         return hidden_states


-class MiniCPMForCausalLM(nn.Module):
+class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -418,13 +423,16 @@ class MiniCPMForCausalLM(nn.Module):

     def __init__(
         self,
-        config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
         self.model = MiniCPMModel(config,
@@ -54,6 +54,8 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 from vllm.utils import print_warning_once

+from .interfaces import SupportsLoRA
+

 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
@@ -472,7 +474,9 @@ class MixtralModel(nn.Module):
         return hidden_states


-class MixtralForCausalLM(nn.Module):
+class MixtralForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     fall_back_to_pt_during_load = False

     packed_modules_mapping = {
@@ -504,7 +508,10 @@ class MixtralForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.model = MixtralModel(config,
                                   cache_config,
                                   quant_config,
@@ -39,7 +39,7 @@ from typing import Iterable, List, Optional, Tuple

 import torch
 from torch import nn
-from transformers import PretrainedConfig
+from transformers import PhiConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
@@ -59,11 +59,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 class PhiAttention(nn.Module):

     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -131,7 +133,7 @@ class PhiAttention(nn.Module):
 class PhiMLP(nn.Module):

     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()

@@ -160,7 +162,7 @@ class PhiMLP(nn.Module):
 class PhiLayer(nn.Module):

     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -192,7 +194,7 @@ class PhiLayer(nn.Module):
 class PhiModel(nn.Module):

     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -229,7 +231,9 @@ class PhiModel(nn.Module):
         return hidden_states


-class PhiForCausalLM(nn.Module):
+class PhiForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -250,14 +254,16 @@ class PhiForCausalLM(nn.Module):

     def __init__(
         self,
-        config: PretrainedConfig,
+        config: PhiConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
-        del lora_config  # Unused.
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config

         self.model = PhiModel(config, cache_config, quant_config)
@@ -48,6 +48,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.utils import print_warning_once

+from .interfaces import SupportsLoRA
+

 class Qwen2MLP(nn.Module):

@@ -263,7 +265,9 @@ class Qwen2Model(nn.Module):
         return hidden_states


-class Qwen2ForCausalLM(nn.Module):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -293,7 +297,6 @@ class Qwen2ForCausalLM(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
-        del lora_config
         # TODO (@robertgshaw2): see if this can be moved out
         if (cache_config.sliding_window is not None
                 and hasattr(config, "max_window_layers")):
@@ -307,7 +310,10 @@ class Qwen2ForCausalLM(nn.Module):
                 ))

         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = Qwen2Model(config, cache_config, quant_config)

@@ -1,12 +0,0 @@
-from torch import nn
-
-from vllm.config import VisionLanguageConfig
-
-
-class VisionLanguageModelBase(nn.Module):
-    """Base class for all vision language models (VLMs)."""
-
-    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
-        super().__init__()
-
-        self.vision_language_config = vision_language_config
@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 class XverseMLP(nn.Module):

@@ -266,7 +268,9 @@ class XverseModel(nn.Module):
         return hidden_states


-class XverseForCausalLM(nn.Module):
+class XverseForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -299,10 +303,13 @@ class XverseForCausalLM(nn.Module):
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config=None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = XverseModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
@@ -22,6 +22,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
+from vllm.model_executor.models.interfaces import supports_lora
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
@@ -225,14 +226,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     self.model_memory_usage / float(2**30))

         if self.lora_config:
-            assert hasattr(self.model, "supported_lora_modules"
-                           ) and self.model.supported_lora_modules, (
-                               "Model does not support LoRA")
-            assert hasattr(
-                self.model,
-                "embedding_modules"), "Model does not have embedding_modules"
-            assert hasattr(self.model, "embedding_padding_modules"
-                           ), "Model does not have embedding_padding_modules"
+            assert supports_lora(self.model), "Model does not support LoRA"
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,