[Model] Support Gemma2 embedding model (#9004)
This commit is contained in:
parent
53b3a33027
commit
15986f598c
@ -277,6 +277,7 @@ class HfRunner:
|
||||
SentenceTransformer(
|
||||
model_name,
|
||||
device="cpu",
|
||||
trust_remote_code=True,
|
||||
).to(dtype=torch_dtype))
|
||||
else:
|
||||
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_llama_embedding.py`.
|
||||
Run `pytest tests/models/embedding/language/test_embedding.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
||||
|
||||
MODELS = [
|
||||
"intfloat/e5-mistral-7b-instruct",
|
||||
"BAAI/bge-multilingual-gemma2",
|
||||
]
|
||||
|
||||
|
||||
@ -28,6 +29,14 @@ def test_models(
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
# The example_prompts has ending "\n", for example:
|
||||
# "Write a short story about a robot that dreams for the first time.\n"
|
||||
# sentence_transformers will strip the input texts, see:
|
||||
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
|
||||
# This makes the input_ids different between hf_model and vllm_model.
|
||||
# So we need to strip the input texts to avoid test failing.
|
||||
example_prompts = [str(s).strip() for s in example_prompts]
|
||||
|
||||
with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
|
||||
|
||||
@ -278,11 +278,14 @@ class Gemma2Model(nn.Module):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states *= self.normalizer
|
||||
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
|
||||
82
vllm/model_executor/models/gemma2_embedding.py
Normal file
82
vllm/model_executor/models/gemma2_embedding.py
Normal file
@ -0,0 +1,82 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.gemma2 import Gemma2Model
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
|
||||
class Gemma2EmbeddingModel(nn.Module):
|
||||
"""A model that uses Gemma2 with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the Gemma2Model and provides an interface for
|
||||
embedding operations and customized pooling functions.
|
||||
|
||||
Attributes:
|
||||
model: An instance of Gemma2Model used for forward operations.
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = Gemma2Model(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model.forward(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.model.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
@ -83,6 +83,7 @@ _GENERATION_MODELS = {
|
||||
_EMBEDDING_MODELS = {
|
||||
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user