[FIX] Fix class naming (#1803)

This commit is contained in:
Zhuohan Li 2023-11-28 14:08:01 -08:00 committed by GitHub
parent b943890484
commit 708e6c18b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 17 deletions

View File

@ -12,8 +12,8 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceGroupOutputs,
SequenceOutputs, SequenceStatus)
SequenceGroupMetadata, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
@ -363,7 +363,7 @@ class LLMEngine:
return current_worst_score >= highest_attainable_score
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutputs) -> None:
outputs: SequenceGroupOutput) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
@ -384,7 +384,7 @@ class LLMEngine:
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutputs] = parent_child_dict[
child_samples: List[SequenceOutput] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove

View File

@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutputs, SequenceOutputs)
SequenceData, SequenceGroupOutput, SequenceOutput)
_SAMPLING_EPS = 1e-5
@ -641,7 +641,7 @@ def _build_sampler_output(
next_token_ids,
group_sample_logprobs):
seq_outputs.append(
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return sampler_output

View File

@ -352,7 +352,7 @@ class SequenceGroupMetadata:
self.block_tables = block_tables
class SequenceOutputs:
class SequenceOutput:
"""The model output associated with a sequence.
Args:
@ -374,40 +374,40 @@ class SequenceOutputs:
self.logprobs = logprobs
def __repr__(self) -> str:
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, "
f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
if not isinstance(other, SequenceOutput):
raise NotImplementedError()
return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)
class SequenceGroupOutputs:
"""The model outputs associated with a sequence group."""
class SequenceGroupOutput:
"""The model output associated with a sequence group."""
def __init__(
self,
samples: List[SequenceOutputs],
samples: List[SequenceOutput],
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
self.prompt_logprobs = prompt_logprobs
def __repr__(self) -> str:
return (f"SequenceGroupOutputs(samples={self.samples}, "
return (f"SequenceGroupOutput(samples={self.samples}, "
f"prompt_logprobs={self.prompt_logprobs})")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceGroupOutputs):
if not isinstance(other, SequenceGroupOutput):
raise NotImplementedError()
return (self.samples == other.samples
and self.prompt_logprobs == other.prompt_logprobs)
# For each sequence group, we generate a list of SequenceOutputs object,
# For each sequence group, we generate a list of SequenceOutput object,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[SequenceGroupOutputs]
SamplerOutput = List[SequenceGroupOutput]