[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)

Author: Nick Hill, 2024-10-18 08:19:53 +01:00 (committed by GitHub)
parent 944dd8edaf
commit 1ffc8a7362
4 changed files with 26 additions and 14 deletions

File 1 of 4: BeamSearchSequence / BeamSearchInstance

@@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
+from vllm.sequence import Logprob
 @dataclass
@@ -11,6 +13,7 @@ class BeamSearchSequence:
     """
     # The tokens includes the prompt.
     tokens: List[int]
+    logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
@@ -28,7 +31,7 @@ class BeamSearchInstance:
     def __init__(self, prompt_tokens: List[int]):
         self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
+            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
         ]
         self.completed: List[BeamSearchSequence] = []
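
Note: a minimal sketch of what the new logprobs field holds at initialization, using a simplified stand-in for the dataclass above (field names are taken from this diff; the real Dict[int, Logprob] entries are stood in by plain floats so the snippet runs on its own):

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class BeamSearchSequence:  # simplified stand-in mirroring the fields above
    tokens: List[int]                 # prompt tokens plus generated tokens
    logprobs: List[Dict[int, float]]  # one logprob dict per generated step
    cum_logprob: float = 0.0
    text: Optional[str] = None

# As in BeamSearchInstance.__init__ above: a fresh beam holds only the prompt
# tokens and an empty per-step logprob list.
beam = BeamSearchSequence(tokens=[1, 2, 3], logprobs=[])
assert beam.logprobs == [] and beam.cum_logprob == 0.0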

File 2 of 4: EngineClient (beam_search)

@@ -59,7 +59,7 @@ class EngineClient(ABC):
     async def beam_search(
         self,
-        prompt: Union[PromptType, List[int]],
+        prompt: Union[str, List[int]],
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -71,9 +71,13 @@ class EngineClient(ABC):
         length_penalty = params.length_penalty
         tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
+        if isinstance(prompt, str):
+            tokenized_prompt = tokenizer.encode(prompt)
+            prompt_text = prompt
+        else:
+            tokenized_prompt = prompt
+            prompt_text = None
+        tokenized_length = len(tokenized_prompt)
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
@@ -81,7 +85,11 @@ class EngineClient(ABC):
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
                                             temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        all_beams = [
+            BeamSearchSequence(tokens=tokenized_prompt,
+                               logprobs=[],
+                               cum_logprob=0)
+        ]
         completed = []
         for _ in range(max_tokens):
@@ -114,6 +122,7 @@ class EngineClient(ABC):
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
@@ -131,22 +140,22 @@ class EngineClient(ABC):
         best_beams = sorted_completed[:beam_width]
         for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+            beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
         beam_search_output = RequestOutput(
             request_id=request_id,
-            prompt=prompt,
+            prompt=prompt_text,
             outputs=[
                 CompletionOutput(
                     text=beam.text,
                     cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
+                    token_ids=beam.tokens[tokenized_length:],
                     index=i,
-                    logprobs=beam.cum_logprob,
+                    logprobs=beam.logprobs,
                 ) for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenizedPrompt,
+            prompt_token_ids=tokenized_prompt,
             prompt_logprobs=None)
         yield beam_search_output
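
Note: a small self-contained sketch of the str-vs-token-ids branch above, which is what lets RequestOutput.prompt be a plain Optional[str]; the encoder below is a trivial stand-in for the tokenizer, not vLLM's:

from typing import List, Optional, Tuple, Union

def split_prompt(prompt: Union[str, List[int]],
                 encode) -> Tuple[List[int], Optional[str]]:
    # Mirrors the branch in beam_search: text prompts are tokenized and kept
    # as prompt_text; token-id prompts pass through and prompt_text stays None.
    if isinstance(prompt, str):
        tokenized_prompt = encode(prompt)
        prompt_text: Optional[str] = prompt
    else:
        tokenized_prompt = prompt
        prompt_text = None
    return tokenized_prompt, prompt_text

fake_encode = lambda s: [ord(c) for c in s]   # stand-in "tokenizer"
print(split_prompt("hi", fake_encode))        # ([104, 105], 'hi')
print(split_prompt([104, 105], fake_encode))  # ([104, 105], None)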

File 3 of 4: LLM (beam search loop)

@@ -433,6 +433,7 @@ class LLM:
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
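
Note: a rough, self-contained sketch of the expansion step that the loops in both beam_search implementations above perform, again with a simplified stand-in class (length penalty, EOS handling, and batching are omitted; token ids and logprob values are made up):

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Beam:  # minimal stand-in for BeamSearchSequence
    tokens: List[int]
    logprobs: List[Dict[int, float]]
    cum_logprob: float = 0.0

beam_width = 2
current_beam = Beam(tokens=[1, 2, 3], logprobs=[])

# One decoding step returns logprobs for the top 2 * beam_width candidate tokens.
step_logprobs = {7: -0.1, 42: -1.2, 13: -2.5, 99: -3.0}

# Each candidate token extends the beam: the whole step's dict is appended to
# logprobs, and the chosen token's logprob is added to cum_logprob.
candidates = [
    Beam(tokens=current_beam.tokens + [token_id],
         logprobs=current_beam.logprobs + [step_logprobs],
         cum_logprob=current_beam.cum_logprob + logprob)
    for token_id, logprob in step_logprobs.items()
]

# Keep the best beam_width candidates by cumulative logprob.
all_beams = sorted(candidates, key=lambda b: b.cum_logprob,
                   reverse=True)[:beam_width]
print([b.tokens[-1] for b in all_beams])  # [7, 42]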

File 4 of 4: RequestOutput

@@ -4,7 +4,6 @@ from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
-from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -93,7 +92,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[PromptType],
+        prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
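
Note: with this change, RequestOutput.prompt is typed as the original prompt text, or None when only token ids were supplied, instead of the broader PromptType. A sketch of a minimal construction mirroring the RequestOutput(...) call in the beam_search diff above; it assumes a vLLM checkout around this commit where both classes are importable from vllm.outputs, and all field values are made up:

from vllm.outputs import CompletionOutput, RequestOutput

output = RequestOutput(
    request_id="req-0",
    prompt="The capital of France is",   # now Optional[str], not PromptType
    prompt_token_ids=[464, 3139, 286, 4881, 318],
    prompt_logprobs=None,
    outputs=[
        CompletionOutput(
            index=0,
            text=" Paris",
            token_ids=[6342],            # completion tokens only, prompt excluded
            cumulative_logprob=-0.05,
            logprobs=None,
        )
    ],
    finished=True,
)
print(output.prompt, output.outputs[0].text)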