[Frontend] re-enable multi-modality input in the new beam search implementation (#9427)
Signed-off-by: Qishuai <Ferdinandzhong@gmail.com>
This commit is contained in: parent eae3d48181, commit ef7865b4f9
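For context, the tests added below drive beam search with image input through the OpenAI-compatible chat API. A minimal client-side sketch of the same request pattern follows; the server URL, API key, model name, and image URL are placeholders rather than part of this commit:

from openai import OpenAI

# Placeholders: assumes a vLLM OpenAI-compatible server is already running
# with a multimodal model loaded.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",  # placeholder model name
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }],
    n=2,
    max_tokens=10,
    extra_body=dict(use_beam_search=True),  # vLLM-specific flag, as in the tests below
)
print([choice.message.content for choice in chat_completion.choices])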
@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
+                                                    model_name: str,
+                                                    image_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        n=2,
+        max_tokens=10,
+        logprobs=True,
+        top_logprobs=5,
+        extra_body=dict(use_beam_search=True))
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[1].message.content
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_base64encoded_beamsearch(
+        client: openai.AsyncOpenAI, model_name: str, image_url: str,
+        base64_encoded_image: Dict[str, str]):
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        n=2,
+        max_tokens=10,
+        extra_body=dict(use_beam_search=True))
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[1].message.content
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@@ -1,8 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from vllm.sequence import Logprob
 
+if TYPE_CHECKING:
+    from vllm.multimodal import MultiModalDataDict
+
 
 @dataclass
 class BeamSearchSequence:
@@ -16,6 +19,10 @@ class BeamSearchSequence:
     logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Union[int, str, None] = None
+    multi_modal_data: Optional["MultiModalDataDict"] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
 
 
 @dataclass
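The extra fields above let each beam carry its prompt's multimodal payload and its own finish state. A small illustrative sketch of constructing a root beam; the token ids and image are made-up values, not taken from this commit:

from PIL import Image

from vllm.beam_search import BeamSearchSequence

# Illustrative values only; in the engine these come from the preprocessed prompt.
image = Image.new("RGB", (64, 64))

root_beam = BeamSearchSequence(
    tokens=[1, 2, 3],                   # prompt token ids
    logprobs=[],
    cum_logprob=0.0,
    multi_modal_data={"image": image},  # re-attached on every expansion step
    mm_processor_kwargs=None,
)
print(root_beam.finish_reason)  # None until the beam hits EOS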
@@ -6,6 +6,7 @@ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -59,7 +60,8 @@ class EngineClient(ABC):
 
     async def beam_search(
         self,
-        prompt: Union[str, List[int]],
+        prompt: Union[PromptType, List[int]],
+        model_config: ModelConfig,
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -69,32 +71,40 @@
         ignore_eos = params.ignore_eos
         temperature = params.temperature
         length_penalty = params.length_penalty
+        include_stop_str_in_output = params.include_stop_str_in_output
 
-        tokenizer = await self.get_tokenizer(lora_request=None)
-        if isinstance(prompt, str):
-            tokenized_prompt = tokenizer.encode(prompt)
-            prompt_text = prompt
-        else:
-            tokenized_prompt = prompt
-            prompt_text = None
-        tokenized_length = len(tokenized_prompt)
+        tokenizer = await self.get_tokenizer()
+        input_preprocessor = InputPreprocessor(model_config, tokenizer)
+
+        (prompt_text, prompt_token_ids, multi_modal_data,
+         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
+             prompt,
+             request_id=request_id,
+         )
+        tokenized_length = len(prompt_token_ids)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
 
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
+        beam_search_params = SamplingParams(
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+        )
         all_beams = [
-            BeamSearchSequence(tokens=tokenized_prompt,
+            BeamSearchSequence(tokens=prompt_token_ids,
+                               cum_logprob=0,
                                logprobs=[],
-                               cum_logprob=0)
+                               multi_modal_data=multi_modal_data,
+                               mm_processor_kwargs=mm_processor_kwargs)
         ]
         completed = []
 
         for _ in range(max_tokens):
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
+                TokensPrompt(prompt_token_ids=beam.tokens,
+                             multi_modal_data=beam.multi_modal_data,
+                             mm_processor_kwargs=beam.mm_processor_kwargs)
                 for beam in all_beams
             ]
 
@@ -120,17 +130,31 @@
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
                     for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            logprobs=current_beam.logprobs + [logprobs],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
-                            completed.append(new_beam)
+                            completed.append(
+                                BeamSearchSequence(
+                                    tokens=current_beam.tokens +
+                                    [token_id] if include_stop_str_in_output
+                                    else current_beam.tokens,
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    finish_reason="stop",
+                                    stop_reason=tokenizer.eos_token_id))
                         else:
-                            new_beams.append(new_beam)
+                            new_beams.append(
+                                BeamSearchSequence(
+                                    tokens=current_beam.tokens + [token_id],
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
+                                    mm_processor_kwargs=current_beam.
+                                    mm_processor_kwargs))
 
             sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
             all_beams = sorted_beams[:beam_width]
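Each iteration above fans every live beam out over its top 2 * beam_width candidate tokens, then keeps only the best beam_width sequences. A self-contained toy of that prune step; the scoring helper below is a simplified stand-in for vLLM's create_sort_beams_key_function, which also length-normalizes the cumulative logprob:

from dataclasses import dataclass
from typing import List


@dataclass
class ToyBeam:
    tokens: List[int]
    cum_logprob: float = 0.0


def sort_key(beam: ToyBeam, length_penalty: float = 1.0) -> float:
    # Length-normalized score so longer beams are not unfairly penalized
    # for having summed more (negative) logprobs.
    return beam.cum_logprob / (len(beam.tokens)**length_penalty)


def prune(candidates: List[ToyBeam], beam_width: int) -> List[ToyBeam]:
    # Keep the beam_width highest-scoring candidates, like sorted_beams[:beam_width].
    return sorted(candidates, key=sort_key, reverse=True)[:beam_width]


candidates = [
    ToyBeam(tokens=[1, 5], cum_logprob=-1.2),
    ToyBeam(tokens=[1, 9], cum_logprob=-0.4),
    ToyBeam(tokens=[1, 2], cum_logprob=-2.5),
]
print(prune(candidates, beam_width=2))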
@@ -151,16 +175,18 @@
             request_id=request_id,
             prompt=prompt_text,
             outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens[tokenized_length:],
-                    index=i,
-                    logprobs=beam.logprobs,
-                ) for (i, beam) in enumerate(best_beams)
+                CompletionOutput(text=beam.text,
+                                 cumulative_logprob=beam.cum_logprob,
+                                 token_ids=beam.tokens[tokenized_length:],
+                                 index=i,
+                                 logprobs=beam.logprobs,
+                                 finish_reason=beam.finish_reason if
+                                 beam.finish_reason is not None else "length",
+                                 stop_reason=beam.stop_reason)
+                for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenized_prompt,
+            prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None)
 
         yield beam_search_output
@@ -308,7 +308,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-        )
+            include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
@@ -606,7 +606,7 @@ class CompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-        )
+            include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
@@ -236,9 +236,10 @@ class OpenAIServingChat(OpenAIServing):
 
             if isinstance(sampling_params, BeamSearchParams):
                 result_generator = self.engine_client.beam_search(
-                    engine_inputs['prompt_token_ids'],
-                    request_id,
-                    sampling_params,
+                    prompt=engine_inputs,
+                    model_config=self.model_config,
+                    request_id=request_id,
+                    params=sampling_params,
                 )
             else:
                 result_generator = self.engine_client.generate(
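The chat frontend now forwards the full engine inputs to beam_search instead of only the token ids, which is what lets the image data reach the engine again. Roughly the shape of that prompt argument for a multimodal request; the values below are illustrative, not taken from this commit:

from PIL import Image

from vllm.inputs.data import TokensPrompt

# Illustrative only: a tokenized prompt with its multimodal payload attached.
engine_inputs = TokensPrompt(
    prompt_token_ids=[1, 2, 3],
    multi_modal_data={"image": Image.new("RGB", (64, 64))},
    mm_processor_kwargs=None,
)
print(engine_inputs["prompt_token_ids"])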
@@ -150,9 +150,13 @@ class OpenAIServingCompletion(OpenAIServing):
 
                 if isinstance(sampling_params, BeamSearchParams):
                     generator = self.engine_client.beam_search(
-                        prompt_inputs["prompt_token_ids"],
-                        request_id_item,
-                        sampling_params,
+                        prompt={
+                            "prompt_token_ids":
+                            prompt_inputs["prompt_token_ids"]
+                        },
+                        model_config=self.model_config,
+                        request_id=request_id,
+                        params=sampling_params,
                     )
                 else:
                     generator = self.engine_client.generate(
@@ -500,3 +500,4 @@ class BeamSearchParams(
     ignore_eos: bool = False
     temperature: float = 0.0
     length_penalty: float = 1.0
+    include_stop_str_in_output: bool = False
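With the new field, the include_stop_str_in_output value from the request is no longer dropped on the beam-search path. A minimal construction sketch with illustrative values (BeamSearchParams lives in vllm.sampling_params; in the server these values come from to_beam_search_params()):

from vllm.sampling_params import BeamSearchParams

# Illustrative values only.
params = BeamSearchParams(
    beam_width=2,
    max_tokens=10,
    ignore_eos=False,
    temperature=0.0,
    length_penalty=1.0,
    include_stop_str_in_output=False,
)
print(params)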