[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)

commit 1ffc8a7362
parent 944dd8edaf
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
 
+from vllm.sequence import Logprob
+
 
 @dataclass
@@ -11,6 +13,7 @@ class BeamSearchSequence:
     """
     # The tokens includes the prompt.
     tokens: List[int]
+    logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
 
@@ -28,7 +31,7 @@ class BeamSearchInstance:
 
     def __init__(self, prompt_tokens: List[int]):
         self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
+            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
        ]
         self.completed: List[BeamSearchSequence] = []
 
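Taken together, the hunks above (presumably vllm/beam_search.py) add a required logprobs field to BeamSearchSequence, so every construction site now has to pass it. Below is a minimal sketch of the dataclass after the change and of seeding an initial beam; only the fields shown in the diff are assumed, the rest is illustrative.

from dataclasses import dataclass
from typing import Dict, List, Optional

from vllm.sequence import Logprob


@dataclass
class BeamSearchSequence:
    """One beam in beam search (fields as introduced by this commit)."""
    # The tokens include the prompt.
    tokens: List[int]
    # One {token_id: Logprob} candidate map per generated step.
    logprobs: List[Dict[int, Logprob]]
    cum_logprob: float = 0.0
    text: Optional[str] = None


# Seeding a fresh beam now requires an explicit (empty) logprobs list:
initial = BeamSearchSequence(tokens=[1, 2, 3], logprobs=[])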
@@ -59,7 +59,7 @@ class EngineClient(ABC):
 
     async def beam_search(
         self,
-        prompt: Union[PromptType, List[int]],
+        prompt: Union[str, List[int]],
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -71,9 +71,13 @@ class EngineClient(ABC):
         length_penalty = params.length_penalty
 
         tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
+        if isinstance(prompt, str):
+            tokenized_prompt = tokenizer.encode(prompt)
+            prompt_text = prompt
+        else:
+            tokenized_prompt = prompt
+            prompt_text = None
+        tokenized_length = len(tokenized_prompt)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
@@ -81,7 +85,11 @@ class EngineClient(ABC):
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
                                             temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        all_beams = [
+            BeamSearchSequence(tokens=tokenized_prompt,
+                               logprobs=[],
+                               cum_logprob=0)
+        ]
         completed = []
 
         for _ in range(max_tokens):
@@ -114,6 +122,7 @@ class EngineClient(ABC):
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
 
@@ -131,22 +140,22 @@ class EngineClient(ABC):
         best_beams = sorted_completed[:beam_width]
 
         for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+            beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
 
         beam_search_output = RequestOutput(
             request_id=request_id,
-            prompt=prompt,
+            prompt=prompt_text,
             outputs=[
                 CompletionOutput(
                     text=beam.text,
                     cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
+                    token_ids=beam.tokens[tokenized_length:],
                     index=i,
-                    logprobs=beam.cum_logprob,
+                    logprobs=beam.logprobs,
                 ) for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenizedPrompt,
+            prompt_token_ids=tokenized_prompt,
             prompt_logprobs=None)
 
         yield beam_search_output
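The main behavioural change in EngineClient.beam_search above is that the prompt is normalized once up front: a string prompt is tokenized and kept around as text, while a token-id list is used as-is and the text stays None, which is what lets RequestOutput.prompt be typed as Optional[str]. A standalone sketch of that branching follows; normalize_prompt and the stand-in encode function are illustrative only, not part of the vLLM API.

from typing import Callable, List, Optional, Tuple, Union


def normalize_prompt(
    prompt: Union[str, List[int]],
    encode: Callable[[str], List[int]],
) -> Tuple[List[int], Optional[str]]:
    """Mirror the branching added by this commit: return (token_ids, text)."""
    if isinstance(prompt, str):
        return encode(prompt), prompt
    # Already tokenized: no prompt text is available.
    return prompt, None


# Usage with a toy "tokenizer":
toy_encode = lambda s: [ord(ch) for ch in s]
print(normalize_prompt("hi", toy_encode))       # ([104, 105], 'hi')
print(normalize_prompt([1, 2, 3], toy_encode))  # ([1, 2, 3], None)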
@@ -433,6 +433,7 @@ class LLM:
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
 
@@ -4,7 +4,6 @@ from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
-from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -93,7 +92,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[PromptType],
+        prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
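With RequestOutput.prompt narrowed to Optional[str], a caller that only has token IDs (as beam_search does for pre-tokenized prompts) passes prompt=None and supplies prompt_token_ids instead. A construction sketch mirroring the beam_search call in the hunks above; the argument values are illustrative only, and only the keyword arguments that appear in this diff are assumed.

from vllm.outputs import CompletionOutput, RequestOutput

output = RequestOutput(
    request_id="beam_search-example",
    prompt=None,                 # pre-tokenized request: no prompt text
    prompt_token_ids=[1, 2, 3],
    prompt_logprobs=None,
    outputs=[
        CompletionOutput(
            index=0,
            text="hello",
            token_ids=[4, 5],            # generated tokens only, prompt stripped
            cumulative_logprob=-1.2,
            logprobs=[],                 # per-step candidate logprob maps
        )
    ],
    finished=True,
)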