Support Roberta embedding models (#9387)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
Maximilien de Bayser, 2024-11-14 18:23:29 -03:00 (committed by GitHub)
parent 1dbae0329c
commit 4a18fd14ba
10 changed files with 202 additions and 14 deletions

View File

@@ -98,6 +98,9 @@ void paged_attention_v1_launcher(
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
     // head sizes that we use in the model. However, we can easily extend this
     // to support any head size which is a multiple of 16.
+    case 32:
+      LAUNCH_PAGED_ATTENTION_V1(32);
+      break;
     case 64:
       LAUNCH_PAGED_ATTENTION_V1(64);
       break;

View File

@@ -104,6 +104,9 @@ void paged_attention_v2_launcher(
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
     // head sizes that we use in the model. However, we can easily extend this
     // to support any head size which is a multiple of 16.
+    case 32:
+      LAUNCH_PAGED_ATTENTION_V2(32);
+      break;
     case 64:
       LAUNCH_PAGED_ATTENTION_V2(64);
       break;

View File

@@ -385,6 +385,9 @@ void paged_attention_v1_impl_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   switch (head_size) {
+    case 32:
+      LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
+      break;
     case 64:
       LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
       break;
@@ -702,6 +705,9 @@ void paged_attention_v2_impl_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   switch (head_size) {
+    case 32:
+      LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
+      break;
     case 64:
       LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
       break;

View File

@@ -4,12 +4,17 @@ import pytest
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.model_executor.models.bert import BertEmbeddingModel
+from vllm.model_executor.models.roberta import RobertaEmbeddingModel
 from vllm.platforms import current_platform
 
 MAX_MODEL_LEN = 128
 MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
 REVISION = os.environ.get("REVISION", "main")
 
+MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME",
+                                    "intfloat/multilingual-e5-large")
+REVISION_ROBERTA = os.environ.get("REVISION", "main")
+
 
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
@@ -48,3 +53,42 @@ def test_model_loading_with_params(vllm_runner):
         assert model._pooler.normalize
         # assert output
         assert output
+
+
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+def test_roberta_model_loading_with_params(vllm_runner):
+    """
+    Test parameter weight loading with tp>1.
+    """
+    with vllm_runner(model_name=MODEL_NAME_ROBERTA,
+                     revision=REVISION_ROBERTA,
+                     dtype="float16",
+                     max_model_len=MAX_MODEL_LEN) as model:
+        output = model.encode("Write a short story about a robot that"
+                              " dreams for the first time.\n")
+
+        model_config = model.model.llm_engine.model_config
+        model_tokenizer = model.model.llm_engine.tokenizer
+
+        # asserts on the bert model config file
+        assert model_config.encoder_config["max_seq_length"] == 512
+        assert not model_config.encoder_config["do_lower_case"]
+
+        # asserts on the pooling config files
+        assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
+        assert model_config.pooler_config.pooling_norm
+
+        # asserts on the tokenizer loaded
+        assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
+        assert not model_tokenizer.tokenizer_config["do_lower_case"]
+
+        model = model.model.llm_engine.model_executor\
+            .driver_worker.model_runner.model
+        assert isinstance(model, RobertaEmbeddingModel)
+        assert model._pooler.pooling_type == PoolingType.MEAN
+        assert model._pooler.normalize
+
+        # assert output
+        assert output
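
The expected values in this new test mirror the checkpoint's sentence-transformers sidecar files. A short, hedged sketch of where they live, assuming the standard sentence-transformers repository layout for this model (this is an illustration, not part of vLLM's loading code):

# Hedged sketch: inspect the sentence-transformers config files behind the
# asserts above (max_seq_length, do_lower_case, mean pooling).
import json

from huggingface_hub import hf_hub_download

repo = "intfloat/multilingual-e5-large"
enc_cfg = json.load(open(hf_hub_download(repo, "sentence_bert_config.json")))
pool_cfg = json.load(open(hf_hub_download(repo, "1_Pooling/config.json")))

print(enc_cfg["max_seq_length"], enc_cfg["do_lower_case"])  # expected: 512 False
print(pool_cfg["pooling_mode_mean_tokens"])                 # expected: True -> MEAN pooling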

View File

@@ -13,10 +13,12 @@ MODELS = [
     "intfloat/e5-mistral-7b-instruct",
     "BAAI/bge-base-en-v1.5",
     "BAAI/bge-multilingual-gemma2",
+    "intfloat/multilingual-e5-large",
 ]
 
 ENCODER_ONLY = [
     "BAAI/bge-base-en-v1.5",
+    "intfloat/multilingual-e5-large",
 ]

View File

@@ -10,7 +10,7 @@ class PagedAttention:
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
+        return [32, 64, 80, 96, 112, 128, 256]
 
     @staticmethod
     def get_kv_cache_shape(

View File

@@ -34,7 +34,7 @@ class PagedAttention:
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 120, 128, 192, 256]
+        return [32, 64, 80, 96, 112, 120, 128, 192, 256]
 
     @staticmethod
     def get_kv_cache_shape(
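
Both supported-head-size lists now include 32, matching the new case 32 branches in the CUDA and CPU launchers above, so smaller encoder checkpoints (e.g. 384 hidden units across 12 heads) can run on the paged-attention path. A minimal sketch of the kind of guard a caller can perform against this list; the import path follows this repository's layout, and the example model dimensions are assumptions:

# Hedged sketch: check a model's head size against the paged-attention
# support list before wiring up the attention backend.
from vllm.attention.ops.paged_attn import PagedAttention

hidden_size, num_attention_heads = 384, 12        # e.g. a small BERT-style encoder
head_size = hidden_size // num_attention_heads    # -> 32, newly supported here

supported = PagedAttention.get_supported_head_sizes()
if head_size not in supported:
    raise ValueError(f"head_size {head_size} is not supported; "
                     f"supported sizes: {supported}")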

View File

@@ -5,7 +5,7 @@ from torch import nn
 from transformers import BertConfig
 
 from vllm.attention import Attention, AttentionMetadata, AttentionType
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -305,14 +305,16 @@ class BertOutput(nn.Module):
 
 class BertModel(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 embedding_class: type = BertEmbedding):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        self.embeddings = BertEmbedding(config)
+        self.embeddings = embedding_class(config)
         self.encoder = BertEncoder(config,
                                    cache_config,
                                    quant_config,
@@ -382,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         pooler_config = vllm_config.model_config.pooler_config
-        self.model = BertModel(vllm_config=vllm_config,
-                               prefix=maybe_prefix(prefix, "model"))
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.CLS,
-            normalize=True,
-            softmax=False)
+        self.model = self._build_model(vllm_config=vllm_config,
+                                       prefix=maybe_prefix(prefix, "model"))
+        self._pooler = self._build_pooler(pooler_config)
 
     def forward(
         self,
@@ -415,3 +413,16 @@ class BertEmbeddingModel(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         self.model.load_weights(weights)
+
+    def _build_model(self,
+                     vllm_config: VllmConfig,
+                     prefix: str = "") -> BertModel:
+        return BertModel(vllm_config=vllm_config,
+                         prefix=prefix,
+                         embedding_class=BertEmbedding)
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        return Pooler.from_config_with_defaults(pooler_config,
+                                                pooling_type=PoolingType.CLS,
+                                                normalize=True,
+                                                softmax=False)
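
The constructor now delegates to the overridable _build_model and _build_pooler hooks, which is exactly what the RoBERTa model added later in this commit uses to swap the embedding layer while keeping the rest of the BERT stack. A minimal, hypothetical subclass sketch of the same pattern; the MEAN-pooling defaults here are illustrative, not part of this diff:

# Hypothetical sketch: a BertEmbeddingModel subclass that reuses BertModel but
# overrides the embedding class and default pooling behaviour via the hooks.
from vllm.config import PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.bert import (BertEmbedding, BertEmbeddingModel,
                                             BertModel)


class MyEmbeddingModel(BertEmbeddingModel):

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        # Swap BertEmbedding for any module exposing the same interface.
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=BertEmbedding)

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        # Different defaults; values set in pooler_config still take precedence.
        return Pooler.from_config_with_defaults(pooler_config,
                                                pooling_type=PoolingType.MEAN,
                                                normalize=True,
                                                softmax=False)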

View File

@@ -94,6 +94,8 @@ _TEXT_GENERATION_MODELS = {
 _EMBEDDING_MODELS = {
     # [Text-only]
     "BertModel": ("bert", "BertEmbeddingModel"),
+    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
+    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
     "LlamaModel": ("llama", "LlamaEmbeddingModel"),

View File

@@ -0,0 +1,117 @@
from typing import List, Optional

import torch
from torch import nn
from transformers import RobertaConfig

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.sequence import IntermediateTensors


class RobertaEmbedding(nn.Module):

    def __init__(self, config: RobertaConfig):
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      config.hidden_size)
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size,
                                                padding_idx=self.padding_idx)

        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)), )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type" +
                             " is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        input_shape = input_ids.size()

        # Input embeddings.
        inputs_embeds = self.word_embeddings(input_ids)

        # TODO: figure out if there is a better way
        # to make position ids start at padding_idx + 1
        # References:
        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
        position_ids += self.padding_idx + 1

        # Position embeddings.
        position_embeddings = self.position_embeddings(position_ids)

        # Token type embeddings. (TODO: move off hotpath?)
        token_type_embeddings = self.token_type_embeddings(
            torch.zeros(input_shape,
                        dtype=torch.long,
                        device=inputs_embeds.device))

        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=RobertaEmbedding)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        # Verify the assumption that positions are always a sequence from
        # 0 to N. (Actually here we just check 0 and N to simplify).
        # This is important to fix the positions, which are assumed to
        # start from padding_idx + 1 instead of 0 in the Roberta models.
        assert hasattr(attn_metadata, "seq_lens_tensor")
        cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
        start_pos = torch.cat(
            (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
             cumulative[:-1]))
        assert len(torch.nonzero(positions[start_pos])) == 0
        end_pos = cumulative - 1
        last_tokens = attn_metadata.seq_lens_tensor - 1
        assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0

        return super().forward(input_ids=input_ids,
                               positions=positions,
                               kv_caches=kv_caches,
                               attn_metadata=attn_metadata,
                               intermediate_tensors=intermediate_tensors,
                               inputs_embeds=inputs_embeds)
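
RoBERTa checkpoints index absolute position embeddings from padding_idx + 1 for real tokens (the pad position itself is reserved), which is why RobertaEmbedding.forward shifts position_ids and why the forward above insists that the incoming flattened positions restart at 0 for every packed sequence. A small self-contained illustration of the cumsum bookkeeping used by those assertions, with toy tensors only:

# Toy illustration of the position sanity check in RobertaEmbeddingModel.forward:
# for flattened batches, the first token of every sequence must sit at position 0
# and the last token at position seq_len - 1.
import torch

seq_lens_tensor = torch.tensor([3, 5, 2])                    # three packed sequences
positions = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])

cumulative = seq_lens_tensor.cumsum(dim=0)                   # [3, 8, 10]
start_pos = torch.cat((torch.tensor([0]), cumulative[:-1]))  # [0, 3, 8]
end_pos = cumulative - 1                                     # [2, 7, 9]
last_tokens = seq_lens_tensor - 1                            # [2, 4, 1]

assert len(torch.nonzero(positions[start_pos])) == 0              # all starts are 0
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0  # ends match lengths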