[FIX] Fix class naming (#1803)
This commit is contained in:
parent
b943890484
commit
708e6c18b0
@ -12,8 +12,8 @@ from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceGroupOutputs,
|
||||
SequenceOutputs, SequenceStatus)
|
||||
SequenceGroupMetadata, SequenceGroupOutput,
|
||||
SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter
|
||||
@ -363,7 +363,7 @@ class LLMEngine:
|
||||
return current_worst_score >= highest_attainable_score
|
||||
|
||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||
outputs: SequenceGroupOutputs) -> None:
|
||||
outputs: SequenceGroupOutput) -> None:
|
||||
# Process prompt logprobs
|
||||
prompt_logprobs = outputs.prompt_logprobs
|
||||
if prompt_logprobs is not None:
|
||||
@ -384,7 +384,7 @@ class LLMEngine:
|
||||
|
||||
# Process the child samples for each parent sequence
|
||||
for parent in parent_seqs:
|
||||
child_samples: List[SequenceOutputs] = parent_child_dict[
|
||||
child_samples: List[SequenceOutput] = parent_child_dict[
|
||||
parent.seq_id]
|
||||
if len(child_samples) == 0:
|
||||
# This parent sequence has no children samples. Remove
|
||||
|
||||
@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
|
||||
SequenceData, SequenceGroupOutputs, SequenceOutputs)
|
||||
SequenceData, SequenceGroupOutput, SequenceOutput)
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@ -641,7 +641,7 @@ def _build_sampler_output(
|
||||
next_token_ids,
|
||||
group_sample_logprobs):
|
||||
seq_outputs.append(
|
||||
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
|
||||
sampler_output.append(
|
||||
SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
|
||||
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
|
||||
return sampler_output
|
||||
|
||||
@ -352,7 +352,7 @@ class SequenceGroupMetadata:
|
||||
self.block_tables = block_tables
|
||||
|
||||
|
||||
class SequenceOutputs:
|
||||
class SequenceOutput:
|
||||
"""The model output associated with a sequence.
|
||||
|
||||
Args:
|
||||
@ -374,40 +374,40 @@ class SequenceOutputs:
|
||||
self.logprobs = logprobs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
|
||||
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
|
||||
f"output_token={self.output_token}, "
|
||||
f"logprobs={self.logprobs})")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceOutputs):
|
||||
if not isinstance(other, SequenceOutput):
|
||||
raise NotImplementedError()
|
||||
return (self.parent_seq_id == other.parent_seq_id
|
||||
and self.output_token == other.output_token
|
||||
and self.logprobs == other.logprobs)
|
||||
|
||||
|
||||
class SequenceGroupOutputs:
|
||||
"""The model outputs associated with a sequence group."""
|
||||
class SequenceGroupOutput:
|
||||
"""The model output associated with a sequence group."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
samples: List[SequenceOutputs],
|
||||
samples: List[SequenceOutput],
|
||||
prompt_logprobs: Optional[PromptLogprobs],
|
||||
) -> None:
|
||||
self.samples = samples
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceGroupOutputs(samples={self.samples}, "
|
||||
return (f"SequenceGroupOutput(samples={self.samples}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs})")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceGroupOutputs):
|
||||
if not isinstance(other, SequenceGroupOutput):
|
||||
raise NotImplementedError()
|
||||
return (self.samples == other.samples
|
||||
and self.prompt_logprobs == other.prompt_logprobs)
|
||||
|
||||
|
||||
# For each sequence group, we generate a list of SequenceOutputs object,
|
||||
# For each sequence group, we generate a list of SequenceOutput object,
|
||||
# each of which contains one possible candidate for the next token.
|
||||
SamplerOutput = List[SequenceGroupOutputs]
|
||||
SamplerOutput = List[SequenceGroupOutput]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user