[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)
This commit is contained in:
parent
944dd8edaf
commit
1ffc8a7362
@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from vllm.sequence import Logprob
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -11,6 +13,7 @@ class BeamSearchSequence:
|
||||
"""
|
||||
# The tokens includes the prompt.
|
||||
tokens: List[int]
|
||||
logprobs: List[Dict[int, Logprob]]
|
||||
cum_logprob: float = 0.0
|
||||
text: Optional[str] = None
|
||||
|
||||
@ -28,7 +31,7 @@ class BeamSearchInstance:
|
||||
|
||||
def __init__(self, prompt_tokens: List[int]):
|
||||
self.beams: List[BeamSearchSequence] = [
|
||||
BeamSearchSequence(tokens=prompt_tokens)
|
||||
BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
|
||||
]
|
||||
self.completed: List[BeamSearchSequence] = []
|
||||
|
||||
|
||||
@ -59,7 +59,7 @@ class EngineClient(ABC):
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: Union[PromptType, List[int]],
|
||||
prompt: Union[str, List[int]],
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
@ -71,9 +71,13 @@ class EngineClient(ABC):
|
||||
length_penalty = params.length_penalty
|
||||
|
||||
tokenizer = await self.get_tokenizer(lora_request=None)
|
||||
tokenizedPrompt = prompt if isinstance(
|
||||
prompt, list) else tokenizer.encode(prompt)
|
||||
tokenizedLength = len(tokenizedPrompt)
|
||||
if isinstance(prompt, str):
|
||||
tokenized_prompt = tokenizer.encode(prompt)
|
||||
prompt_text = prompt
|
||||
else:
|
||||
tokenized_prompt = prompt
|
||||
prompt_text = None
|
||||
tokenized_length = len(tokenized_prompt)
|
||||
|
||||
sort_beams_key = create_sort_beams_key_function(
|
||||
tokenizer.eos_token_id, length_penalty)
|
||||
@ -81,7 +85,11 @@ class EngineClient(ABC):
|
||||
beam_search_params = SamplingParams(logprobs=2 * beam_width,
|
||||
max_tokens=1,
|
||||
temperature=temperature)
|
||||
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
|
||||
all_beams = [
|
||||
BeamSearchSequence(tokens=tokenized_prompt,
|
||||
logprobs=[],
|
||||
cum_logprob=0)
|
||||
]
|
||||
completed = []
|
||||
|
||||
for _ in range(max_tokens):
|
||||
@ -114,6 +122,7 @@ class EngineClient(ABC):
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
new_beam = BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
cum_logprob=current_beam.cum_logprob +
|
||||
logprob_obj.logprob)
|
||||
|
||||
@ -131,22 +140,22 @@ class EngineClient(ABC):
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
for beam in best_beams:
|
||||
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
|
||||
beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
|
||||
|
||||
beam_search_output = RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt=prompt_text,
|
||||
outputs=[
|
||||
CompletionOutput(
|
||||
text=beam.text,
|
||||
cumulative_logprob=beam.cum_logprob,
|
||||
token_ids=beam.tokens,
|
||||
token_ids=beam.tokens[tokenized_length:],
|
||||
index=i,
|
||||
logprobs=beam.cum_logprob,
|
||||
logprobs=beam.logprobs,
|
||||
) for (i, beam) in enumerate(best_beams)
|
||||
],
|
||||
finished=True,
|
||||
prompt_token_ids=tokenizedPrompt,
|
||||
prompt_token_ids=tokenized_prompt,
|
||||
prompt_logprobs=None)
|
||||
|
||||
yield beam_search_output
|
||||
|
||||
@ -433,6 +433,7 @@ class LLM:
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
new_beam = BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
cum_logprob=current_beam.cum_logprob +
|
||||
logprob_obj.logprob)
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Union
|
||||
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
|
||||
@ -93,7 +92,7 @@ class RequestOutput:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[PromptType],
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]],
|
||||
prompt_logprobs: Optional[PromptLogprobs],
|
||||
outputs: List[CompletionOutput],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user