[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)

Nick Hill authored 2024-10-18 08:19:53 +01:00, committed by GitHub
parent 944dd8edaf
commit 1ffc8a7362
4 changed files with 26 additions and 14 deletions

View File

@@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
 
+from vllm.sequence import Logprob
+
 
 @dataclass
@@ -11,6 +13,7 @@ class BeamSearchSequence:
     """
     # The tokens includes the prompt.
     tokens: List[int]
+    logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
@@ -28,7 +31,7 @@ class BeamSearchInstance:
 
     def __init__(self, prompt_tokens: List[int]):
         self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
+            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
         ]
         self.completed: List[BeamSearchSequence] = []
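
For reference, the dataclass after these hunks reads roughly as follows; this is a sketch assembled from the diff above, with the class docstring abridged:

from dataclasses import dataclass
from typing import Dict, List, Optional

from vllm.sequence import Logprob


@dataclass
class BeamSearchSequence:
    """A beam search candidate sequence (original docstring abridged)."""
    # The tokens includes the prompt.
    tokens: List[int]
    # One logprobs dict per generated step, mapping each candidate token id
    # to its Logprob, so finished beams can report per-token logprobs.
    logprobs: List[Dict[int, Logprob]]
    cum_logprob: float = 0.0
    text: Optional[str] = None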

View File

@@ -59,7 +59,7 @@ class EngineClient(ABC):
     async def beam_search(
         self,
-        prompt: Union[PromptType, List[int]],
+        prompt: Union[str, List[int]],
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -71,9 +71,13 @@ class EngineClient(ABC):
         length_penalty = params.length_penalty
 
         tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
+        if isinstance(prompt, str):
+            tokenized_prompt = tokenizer.encode(prompt)
+            prompt_text = prompt
+        else:
+            tokenized_prompt = prompt
+            prompt_text = None
+        tokenized_length = len(tokenized_prompt)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
@@ -81,7 +85,11 @@ class EngineClient(ABC):
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
                                             temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        all_beams = [
+            BeamSearchSequence(tokens=tokenized_prompt,
+                               logprobs=[],
+                               cum_logprob=0)
+        ]
         completed = []
 
         for _ in range(max_tokens):
@@ -114,6 +122,7 @@ class EngineClient(ABC):
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
@@ -131,22 +140,22 @@ class EngineClient(ABC):
         best_beams = sorted_completed[:beam_width]
 
         for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+            beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
 
         beam_search_output = RequestOutput(
             request_id=request_id,
-            prompt=prompt,
+            prompt=prompt_text,
             outputs=[
                 CompletionOutput(
                     text=beam.text,
                     cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
+                    token_ids=beam.tokens[tokenized_length:],
                     index=i,
-                    logprobs=beam.cum_logprob,
+                    logprobs=beam.logprobs,
                 ) for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenizedPrompt,
+            prompt_token_ids=tokenized_prompt,
             prompt_logprobs=None)
 
         yield beam_search_output
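
The new prompt handling can be read on its own as the sketch below; resolve_prompt is an illustrative helper, not a function added by this commit. A str prompt is tokenized and kept as prompt_text, while an already tokenized List[int] prompt is used as-is, which is why RequestOutput.prompt becomes None in that case.

from typing import List, Optional, Tuple, Union


def resolve_prompt(prompt: Union[str, List[int]],
                   tokenizer) -> Tuple[List[int], Optional[str]]:
    """Illustrative helper mirroring the branch added above."""
    if isinstance(prompt, str):
        # Text prompt: tokenize it; the text becomes RequestOutput.prompt.
        return tokenizer.encode(prompt), prompt
    # Pre-tokenized prompt: no text is available, so prompt_text is None.
    return prompt, None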

View File

@@ -433,6 +433,7 @@ class LLM:
                         for token_id, logprob_obj in logprobs.items():
                             new_beam = BeamSearchSequence(
                                 tokens=current_beam.tokens + [token_id],
+                                logprobs=current_beam.logprobs + [logprobs],
                                 cum_logprob=current_beam.cum_logprob +
                                 logprob_obj.logprob)
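
The effect of this line, in both the async and the LLM code paths, is that a beam's logprobs list grows by one dict per generated token, and the appended dict is the full candidate map for that step rather than just the chosen token's entry. A tiny illustration, with plain floats standing in for vllm's Logprob objects:

from typing import Dict, List

# Candidate logprobs returned for two consecutive steps (token id -> logprob).
step1: Dict[int, float] = {42: -0.11, 7: -2.30}
step2: Dict[int, float] = {13: -0.05, 99: -3.10}

beam_logprobs: List[Dict[int, float]] = []
beam_logprobs = beam_logprobs + [step1]  # beam extended with token 42
beam_logprobs = beam_logprobs + [step2]  # beam extended with token 13
assert len(beam_logprobs) == 2           # one dict per generated token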

View File

@@ -4,7 +4,6 @@ from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
-from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -93,7 +92,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[PromptType],
+        prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
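
With this hunk, RequestOutput.prompt is typed as the decoded prompt text (Optional[str]) rather than PromptType, matching what beam_search now passes. A consumer-side sketch, assuming output is a vllm RequestOutput:

from typing import List, Optional


def describe_prompt(output) -> str:
    """Illustrative only: summarizes how the prompt was supplied."""
    prompt: Optional[str] = output.prompt  # None when the prompt was pre-tokenized
    token_ids: Optional[List[int]] = output.prompt_token_ids
    if prompt is not None:
        return f"text prompt: {prompt!r}"
    return f"pre-tokenized prompt with {len(token_ids or [])} tokens"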