[FIX] Fix class naming (#1803)
This commit is contained in:
parent
b943890484
commit
708e6c18b0
@ -12,8 +12,8 @@ from vllm.logger import init_logger
|
|||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||||
SequenceGroupMetadata, SequenceGroupOutputs,
|
SequenceGroupMetadata, SequenceGroupOutput,
|
||||||
SequenceOutputs, SequenceStatus)
|
SequenceOutput, SequenceStatus)
|
||||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||||
get_tokenizer)
|
get_tokenizer)
|
||||||
from vllm.utils import Counter
|
from vllm.utils import Counter
|
||||||
@ -363,7 +363,7 @@ class LLMEngine:
|
|||||||
return current_worst_score >= highest_attainable_score
|
return current_worst_score >= highest_attainable_score
|
||||||
|
|
||||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||||
outputs: SequenceGroupOutputs) -> None:
|
outputs: SequenceGroupOutput) -> None:
|
||||||
# Process prompt logprobs
|
# Process prompt logprobs
|
||||||
prompt_logprobs = outputs.prompt_logprobs
|
prompt_logprobs = outputs.prompt_logprobs
|
||||||
if prompt_logprobs is not None:
|
if prompt_logprobs is not None:
|
||||||
@ -384,7 +384,7 @@ class LLMEngine:
|
|||||||
|
|
||||||
# Process the child samples for each parent sequence
|
# Process the child samples for each parent sequence
|
||||||
for parent in parent_seqs:
|
for parent in parent_seqs:
|
||||||
child_samples: List[SequenceOutputs] = parent_child_dict[
|
child_samples: List[SequenceOutput] = parent_child_dict[
|
||||||
parent.seq_id]
|
parent.seq_id]
|
||||||
if len(child_samples) == 0:
|
if len(child_samples) == 0:
|
||||||
# This parent sequence has no children samples. Remove
|
# This parent sequence has no children samples. Remove
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
|
|||||||
tensor_model_parallel_all_gather)
|
tensor_model_parallel_all_gather)
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
|
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
|
||||||
SequenceData, SequenceGroupOutputs, SequenceOutputs)
|
SequenceData, SequenceGroupOutput, SequenceOutput)
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
@ -641,7 +641,7 @@ def _build_sampler_output(
|
|||||||
next_token_ids,
|
next_token_ids,
|
||||||
group_sample_logprobs):
|
group_sample_logprobs):
|
||||||
seq_outputs.append(
|
seq_outputs.append(
|
||||||
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
|
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
|
||||||
sampler_output.append(
|
sampler_output.append(
|
||||||
SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
|
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
|
||||||
return sampler_output
|
return sampler_output
|
||||||
|
|||||||
@ -352,7 +352,7 @@ class SequenceGroupMetadata:
|
|||||||
self.block_tables = block_tables
|
self.block_tables = block_tables
|
||||||
|
|
||||||
|
|
||||||
class SequenceOutputs:
|
class SequenceOutput:
|
||||||
"""The model output associated with a sequence.
|
"""The model output associated with a sequence.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -374,40 +374,40 @@ class SequenceOutputs:
|
|||||||
self.logprobs = logprobs
|
self.logprobs = logprobs
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
|
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
|
||||||
f"output_token={self.output_token}, "
|
f"output_token={self.output_token}, "
|
||||||
f"logprobs={self.logprobs})")
|
f"logprobs={self.logprobs})")
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, SequenceOutputs):
|
if not isinstance(other, SequenceOutput):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
return (self.parent_seq_id == other.parent_seq_id
|
return (self.parent_seq_id == other.parent_seq_id
|
||||||
and self.output_token == other.output_token
|
and self.output_token == other.output_token
|
||||||
and self.logprobs == other.logprobs)
|
and self.logprobs == other.logprobs)
|
||||||
|
|
||||||
|
|
||||||
class SequenceGroupOutputs:
|
class SequenceGroupOutput:
|
||||||
"""The model outputs associated with a sequence group."""
|
"""The model output associated with a sequence group."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
samples: List[SequenceOutputs],
|
samples: List[SequenceOutput],
|
||||||
prompt_logprobs: Optional[PromptLogprobs],
|
prompt_logprobs: Optional[PromptLogprobs],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.samples = samples
|
self.samples = samples
|
||||||
self.prompt_logprobs = prompt_logprobs
|
self.prompt_logprobs = prompt_logprobs
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"SequenceGroupOutputs(samples={self.samples}, "
|
return (f"SequenceGroupOutput(samples={self.samples}, "
|
||||||
f"prompt_logprobs={self.prompt_logprobs})")
|
f"prompt_logprobs={self.prompt_logprobs})")
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, SequenceGroupOutputs):
|
if not isinstance(other, SequenceGroupOutput):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
return (self.samples == other.samples
|
return (self.samples == other.samples
|
||||||
and self.prompt_logprobs == other.prompt_logprobs)
|
and self.prompt_logprobs == other.prompt_logprobs)
|
||||||
|
|
||||||
|
|
||||||
# For each sequence group, we generate a list of SequenceOutputs object,
|
# For each sequence group, we generate a list of SequenceOutput object,
|
||||||
# each of which contains one possible candidate for the next token.
|
# each of which contains one possible candidate for the next token.
|
||||||
SamplerOutput = List[SequenceGroupOutputs]
|
SamplerOutput = List[SequenceGroupOutput]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user