[Frontend] re-enable multi-modality input in the new beam search implementation (#9427)

Signed-off-by: Qishuai <Ferdinandzhong@gmail.com>
Zhong Qishuai 2024-10-29 19:49:47 +08:00 committed by GitHub
parent eae3d48181
commit ef7865b4f9
7 changed files with 150 additions and 40 deletions

View File

@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
+                                                    model_name: str,
+                                                    image_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        n=2,
+        max_tokens=10,
+        logprobs=True,
+        top_logprobs=5,
+        extra_body=dict(use_beam_search=True))
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[1].message.content
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)

@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_base64encoded_beamsearch(
+        client: openai.AsyncOpenAI, model_name: str, image_url: str,
+        base64_encoded_image: Dict[str, str]):
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        n=2,
+        max_tokens=10,
+        extra_body=dict(use_beam_search=True))
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[1].message.content
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
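Note: the base64 test above depends on a base64_encoded_image fixture that is not part of this diff. A minimal sketch of what such a fixture is assumed to provide (the fetch-and-encode logic and fixture scope below are illustrative, not taken from the repository):

# Hypothetical sketch of the base64_encoded_image fixture the test assumes:
# a mapping from each test image URL to its base64-encoded image bytes.
import base64

import pytest
import requests  # assumption: the images are fetched over HTTP


@pytest.fixture(scope="session")
def base64_encoded_image() -> dict:
    return {
        image_url:
        base64.b64encode(requests.get(image_url).content).decode("utf-8")
        for image_url in TEST_IMAGE_URLS  # defined in the test module
    }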

View File

@@ -1,8 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from vllm.sequence import Logprob
 
+if TYPE_CHECKING:
+    from vllm.multimodal import MultiModalDataDict
+
 
 @dataclass
 class BeamSearchSequence:

@@ -16,6 +19,10 @@ class BeamSearchSequence:
     logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Union[int, str, None] = None
+    multi_modal_data: Optional["MultiModalDataDict"] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
 
 
 @dataclass
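For orientation, a minimal sketch of seeding a beam that carries image data, mirroring how EngineClient.beam_search constructs its root beam with these new fields (the token ids and placeholder image are illustrative):

# Illustrative only: a root BeamSearchSequence that keeps multimodal inputs
# attached to the beam so they can be re-sent with every expansion step.
from PIL import Image

from vllm.beam_search import BeamSearchSequence

prompt_token_ids = [1, 2, 3]          # placeholder prompt token ids
image = Image.new("RGB", (336, 336))  # placeholder image

root_beam = BeamSearchSequence(tokens=prompt_token_ids,
                               logprobs=[],
                               cum_logprob=0.0,
                               multi_modal_data={"image": image},
                               mm_processor_kwargs=None)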

View File

@@ -6,6 +6,7 @@ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput

@@ -59,7 +60,8 @@ class EngineClient(ABC):
 
     async def beam_search(
         self,
-        prompt: Union[str, List[int]],
+        prompt: Union[PromptType, List[int]],
+        model_config: ModelConfig,
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:

@@ -69,32 +71,40 @@
         ignore_eos = params.ignore_eos
         temperature = params.temperature
         length_penalty = params.length_penalty
+        include_stop_str_in_output = params.include_stop_str_in_output
 
-        tokenizer = await self.get_tokenizer(lora_request=None)
-        if isinstance(prompt, str):
-            tokenized_prompt = tokenizer.encode(prompt)
-            prompt_text = prompt
-        else:
-            tokenized_prompt = prompt
-            prompt_text = None
-        tokenized_length = len(tokenized_prompt)
+        tokenizer = await self.get_tokenizer()
+        input_preprocessor = InputPreprocessor(model_config, tokenizer)
+
+        (prompt_text, prompt_token_ids, multi_modal_data,
+         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
+             prompt,
+             request_id=request_id,
+         )
+        tokenized_length = len(prompt_token_ids)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
 
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
+        beam_search_params = SamplingParams(
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+        )
         all_beams = [
-            BeamSearchSequence(tokens=tokenized_prompt,
-                               logprobs=[],
-                               cum_logprob=0)
+            BeamSearchSequence(tokens=prompt_token_ids,
+                               cum_logprob=0,
+                               logprobs=[],
+                               multi_modal_data=multi_modal_data,
+                               mm_processor_kwargs=mm_processor_kwargs)
         ]
         completed = []
 
         for _ in range(max_tokens):
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
+                TokensPrompt(prompt_token_ids=beam.tokens,
+                             multi_modal_data=beam.multi_modal_data,
+                             mm_processor_kwargs=beam.mm_processor_kwargs)
                 for beam in all_beams
             ]

@@ -120,17 +130,31 @@
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
                     for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            logprobs=current_beam.logprobs + [logprobs],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
-                            completed.append(new_beam)
+                            completed.append(
+                                BeamSearchSequence(
+                                    tokens=current_beam.tokens +
+                                    [token_id] if include_stop_str_in_output
+                                    else current_beam.tokens,
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    finish_reason="stop",
+                                    stop_reason=tokenizer.eos_token_id))
                         else:
-                            new_beams.append(new_beam)
+                            new_beams.append(
+                                BeamSearchSequence(
+                                    tokens=current_beam.tokens + [token_id],
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
+                                    mm_processor_kwargs=current_beam.
+                                    mm_processor_kwargs))
 
             sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
             all_beams = sorted_beams[:beam_width]

@@ -151,16 +175,18 @@
             request_id=request_id,
             prompt=prompt_text,
             outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens[tokenized_length:],
-                    index=i,
-                    logprobs=beam.logprobs,
-                ) for (i, beam) in enumerate(best_beams)
+                CompletionOutput(text=beam.text,
+                                 cumulative_logprob=beam.cum_logprob,
+                                 token_ids=beam.tokens[tokenized_length:],
+                                 index=i,
+                                 logprobs=beam.logprobs,
+                                 finish_reason=beam.finish_reason if
+                                 beam.finish_reason is not None else "length",
+                                 stop_reason=beam.stop_reason)
+                for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenized_prompt,
+            prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None)
 
         yield beam_search_output
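Taken together, beam_search now accepts a full PromptType (so multimodal data survives preprocessing), re-attaches that data to every TokensPrompt it issues, and tags finished beams with a finish/stop reason. A hedged consumption sketch, assuming an already-constructed async engine_client, its model_config, and a multimodal prompt dict (all names below are illustrative):

# Sketch only: driving the updated beam_search signature from caller code.
from vllm.sampling_params import BeamSearchParams


async def demo(engine_client, model_config, prompt):
    # beam_width/max_tokens are pre-existing BeamSearchParams fields;
    # include_stop_str_in_output is the field added by this commit.
    params = BeamSearchParams(beam_width=2,
                              max_tokens=10,
                              include_stop_str_in_output=False)
    async for request_output in engine_client.beam_search(
            prompt=prompt,              # may carry multi_modal_data
            model_config=model_config,
            request_id="beam-demo-0",   # illustrative request id
            params=params):
        for completion in request_output.outputs:
            # finish_reason is "stop" when EOS ended the beam, else "length".
            print(completion.finish_reason, completion.text)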

View File

@@ -308,7 +308,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-        )
+            include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens

@@ -606,7 +606,7 @@ class CompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-        )
+            include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
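With both request classes forwarding include_stop_str_in_output into BeamSearchParams, the flag can be exercised end to end through the OpenAI-compatible API. A hedged example; the server address and model name are placeholders:

import openai

# Assumed local OpenAI-compatible vLLM server; adjust base_url/model as needed.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def beam_search_chat():
    # extra_body keys map onto vLLM's extra request fields, which populate
    # BeamSearchParams, including the new include_stop_str_in_output flag.
    return await client.chat.completions.create(
        model="llava-hf/llava-1.5-7b-hf",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
        n=2,
        max_tokens=10,
        extra_body=dict(use_beam_search=True,
                        include_stop_str_in_output=True))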

View File

@@ -236,9 +236,10 @@ class OpenAIServingChat(OpenAIServing):
             if isinstance(sampling_params, BeamSearchParams):
                 result_generator = self.engine_client.beam_search(
-                    engine_inputs['prompt_token_ids'],
-                    request_id,
-                    sampling_params,
+                    prompt=engine_inputs,
+                    model_config=self.model_config,
+                    request_id=request_id,
+                    params=sampling_params,
                 )
             else:
                 result_generator = self.engine_client.generate(

View File

@@ -150,9 +150,13 @@ class OpenAIServingCompletion(OpenAIServing):
                 if isinstance(sampling_params, BeamSearchParams):
                     generator = self.engine_client.beam_search(
-                        prompt_inputs["prompt_token_ids"],
-                        request_id_item,
-                        sampling_params,
+                        prompt={
+                            "prompt_token_ids":
+                            prompt_inputs["prompt_token_ids"]
+                        },
+                        model_config=self.model_config,
+                        request_id=request_id,
+                        params=sampling_params,
                     )
                 else:
                     generator = self.engine_client.generate(

View File

@@ -500,3 +500,4 @@ class BeamSearchParams(
     ignore_eos: bool = False
     temperature: float = 0.0
     length_penalty: float = 1.0
+    include_stop_str_in_output: bool = False