[Model] Support Gemma2 embedding model (#9004)
parent 53b3a33027
commit 15986f598c
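For context, a minimal offline usage sketch of the newly supported model, following vLLM's existing embedding example for e5-mistral. The `LLM.encode` call and the `output.outputs.embedding` field are assumed to behave for this model exactly as they do for the existing embedding models; nothing below is part of this diff.

from vllm import LLM

# Sketch: embed a few prompts with the newly supported Gemma2-based encoder.
llm = LLM(model="BAAI/bge-multilingual-gemma2")

prompts = ["Hello, my name is", "The capital of France is"]
outputs = llm.encode(prompts)
for output in outputs:
    embedding = output.outputs.embedding  # list of floats, one per hidden dim
    print(len(embedding))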
@@ -277,6 +277,7 @@ class HfRunner:
                 SentenceTransformer(
                     model_name,
                     device="cpu",
+                    trust_remote_code=True,
                 ).to(dtype=torch_dtype))
         else:
             model_kwargs = model_kwargs if model_kwargs is not None else {}
@@ -1,6 +1,6 @@
 """Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
 
-Run `pytest tests/models/test_llama_embedding.py`.
+Run `pytest tests/models/embedding/language/test_embedding.py`.
 """
 import pytest
 import torch
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 
 MODELS = [
     "intfloat/e5-mistral-7b-instruct",
+    "BAAI/bge-multilingual-gemma2",
 ]
 
 
@@ -28,6 +29,14 @@ def test_models(
     model: str,
     dtype: str,
 ) -> None:
+    # The example_prompts has ending "\n", for example:
+    # "Write a short story about a robot that dreams for the first time.\n"
+    # sentence_transformers will strip the input texts, see:
+    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+    # This makes the input_ids different between hf_model and vllm_model.
+    # So we need to strip the input texts to avoid test failing.
+    example_prompts = [str(s).strip() for s in example_prompts]
+
     with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)
 
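To illustrate the mismatch the stripping works around, a standalone sketch; the tokenizer choice and the exact token difference are assumptions for illustration, not part of this change.

from transformers import AutoTokenizer

# The trailing "\n" kept in example_prompts adds a token, so HF (whose
# sentence_transformers path strips inputs) and vLLM would otherwise see
# different input_ids for the same prompt.
tok = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct")
prompt = "Write a short story about a robot that dreams for the first time.\n"
print(tok(prompt).input_ids == tok(prompt.strip()).input_ids)  # expected: False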
@@ -278,11 +278,14 @@ class Gemma2Model(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_tokens(input_ids)
             hidden_states *= self.normalizer
-
             residual = None
         else:
             assert intermediate_tensors is not None
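Note that Gemma2 scales input embeddings by sqrt(hidden_size) (self.normalizer), and the new branch keeps that scaling for both token ids and precomputed embeddings. A tiny standalone sketch with illustrative values (the hidden_size below is gemma-2-9b's, as an assumption):

import torch

hidden_size = 3584                           # illustrative value
normalizer = hidden_size ** 0.5
inputs_embeds = torch.randn(2, hidden_size)
hidden_states = inputs_embeds * normalizer   # same scaling as the embed_tokens path
print(hidden_states.shape)                   # torch.Size([2, 3584])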
vllm/model_executor/models/gemma2_embedding.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.gemma2 import Gemma2Model
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+
+
+class Gemma2EmbeddingModel(nn.Module):
+    """A model that uses Gemma2 with additional embedding functionalities.
+
+    This class encapsulates the Gemma2Model and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of Gemma2Model used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.model = Gemma2Model(**kwargs)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.model.forward(input_ids, positions, kv_caches,
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.model.named_parameters())
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
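The wrapper delegates pooling to Pooler(pooling_type=PoolingType.LAST, normalize=True), i.e. last-token pooling followed by L2 normalization. A standalone sketch of that idea (not vLLM code; it assumes hidden_states is a flattened [num_tokens, hidden_size] tensor and seq_lens holds each prompt's length):

from typing import List

import torch
import torch.nn.functional as F


def last_token_pool(hidden_states: torch.Tensor,
                    seq_lens: List[int]) -> torch.Tensor:
    # Index of each prompt's final token in the flattened token dimension.
    last_indices = torch.cumsum(torch.tensor(seq_lens), dim=0) - 1
    pooled = hidden_states[last_indices]        # one vector per prompt
    return F.normalize(pooled, p=2, dim=-1)     # unit-length embeddings


# Example: three prompts of lengths 4, 2 and 5 packed into one tensor.
hidden = torch.randn(11, 8)
print(last_token_pool(hidden, [4, 2, 5]).shape)  # torch.Size([3, 8])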
@@ -83,6 +83,7 @@ _GENERATION_MODELS = {
 _EMBEDDING_MODELS = {
     "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+    "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
 }
 
 _MULTIMODAL_MODELS = {
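The registry key is matched against the `architectures` field of a checkpoint's HF config, which is what routes BAAI/bge-multilingual-gemma2 to the new wrapper. A quick way to check (assuming that checkpoint's config.json reports "Gemma2Model"):

from transformers import AutoConfig

# Assumption: the embedding checkpoint advertises "Gemma2Model", matching the
# key added to _EMBEDDING_MODELS above.
config = AutoConfig.from_pretrained("BAAI/bge-multilingual-gemma2")
print(config.architectures)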