[FIX] Fix class naming (#1803)

Zhuohan Li 2023-11-28 14:08:01 -08:00 committed by GitHub
parent b943890484
commit 708e6c18b0
3 changed files with 17 additions and 17 deletions

View File

@@ -12,8 +12,8 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                           SequenceGroupMetadata, SequenceGroupOutputs,
-                           SequenceOutputs, SequenceStatus)
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
@@ -363,7 +363,7 @@ class LLMEngine:
         return current_worst_score >= highest_attainable_score
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
-                                        outputs: SequenceGroupOutputs) -> None:
+                                        outputs: SequenceGroupOutput) -> None:
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
         if prompt_logprobs is not None:
@@ -384,7 +384,7 @@ class LLMEngine:
 
         # Process the child samples for each parent sequence
         for parent in parent_seqs:
-            child_samples: List[SequenceOutputs] = parent_child_dict[
+            child_samples: List[SequenceOutput] = parent_child_dict[
                 parent.seq_id]
             if len(child_samples) == 0:
                 # This parent sequence has no children samples. Remove

View File

@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather) tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutputs, SequenceOutputs) SequenceData, SequenceGroupOutput, SequenceOutput)
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
@@ -641,7 +641,7 @@ def _build_sampler_output(
                                                       next_token_ids,
                                                       group_sample_logprobs):
             seq_outputs.append(
-                SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
+                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
         sampler_output.append(
-            SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
+            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
     return sampler_output

View File

@@ -352,7 +352,7 @@ class SequenceGroupMetadata:
         self.block_tables = block_tables
 
 
-class SequenceOutputs:
+class SequenceOutput:
     """The model output associated with a sequence.
 
     Args:
@@ -374,40 +374,40 @@ class SequenceOutputs:
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
-        return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
+        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
                 f"logprobs={self.logprobs})")
 
     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, SequenceOutputs):
+        if not isinstance(other, SequenceOutput):
             raise NotImplementedError()
         return (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token
                 and self.logprobs == other.logprobs)
 
 
-class SequenceGroupOutputs:
-    """The model outputs associated with a sequence group."""
+class SequenceGroupOutput:
+    """The model output associated with a sequence group."""
 
     def __init__(
         self,
-        samples: List[SequenceOutputs],
+        samples: List[SequenceOutput],
         prompt_logprobs: Optional[PromptLogprobs],
     ) -> None:
         self.samples = samples
         self.prompt_logprobs = prompt_logprobs
 
     def __repr__(self) -> str:
-        return (f"SequenceGroupOutputs(samples={self.samples}, "
+        return (f"SequenceGroupOutput(samples={self.samples}, "
                 f"prompt_logprobs={self.prompt_logprobs})")
 
     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, SequenceGroupOutputs):
+        if not isinstance(other, SequenceGroupOutput):
             raise NotImplementedError()
         return (self.samples == other.samples
                 and self.prompt_logprobs == other.prompt_logprobs)
 
 
-# For each sequence group, we generate a list of SequenceOutputs object,
+# For each sequence group, we generate a list of SequenceOutput object,
 # each of which contains one possible candidate for the next token.
-SamplerOutput = List[SequenceGroupOutputs]
+SamplerOutput = List[SequenceGroupOutput]
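
For reference, a minimal usage sketch of the renamed classes as they read after this commit. This is not part of the change itself: the token id and logprob values are invented, and the logprobs argument is assumed to be a plain dict mapping token id to log probability, as built by the sampler at this revision.

# Usage sketch only -- illustrative values, assuming vLLM at this revision.
from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput

# One candidate next token (id 42) for the parent sequence with id 0;
# the third argument is a dict of token id -> log probability.
sample = SequenceOutput(0, 42, {42: -0.1})

# The outputs for one sequence group: its sampled children plus optional
# prompt logprobs (None when prompt logprobs were not requested).
group_output = SequenceGroupOutput([sample], None)

# After this commit, SamplerOutput is a list of SequenceGroupOutput,
# one entry per sequence group in the batch.
sampler_output: SamplerOutput = [group_output]

# __eq__ compares fields, so equivalent objects compare equal.
assert group_output == SequenceGroupOutput([sample], None)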