[core][misc] simply output processing with shortcut code path (#7117)
This commit is contained in:
parent
9fadc7b7a0
commit
83c644fe7e
@ -81,6 +81,29 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
|
|
||||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||||
outputs: SequenceGroupOutput) -> None:
|
outputs: SequenceGroupOutput) -> None:
|
||||||
|
sampling_params = seq_group.sampling_params
|
||||||
|
if sampling_params.n == 1 and not sampling_params.use_beam_search:
|
||||||
|
# only have one output sample
|
||||||
|
sample = outputs.samples[0]
|
||||||
|
# only have one sequence
|
||||||
|
seq = seq_group.seqs[0]
|
||||||
|
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||||
|
if sampling_params.detokenize and self.detokenizer:
|
||||||
|
new_char_count = self.detokenizer.decode_sequence_inplace(
|
||||||
|
seq, sampling_params)
|
||||||
|
else:
|
||||||
|
new_char_count = 0
|
||||||
|
self.stop_checker.maybe_stop_sequence(
|
||||||
|
seq,
|
||||||
|
new_char_count,
|
||||||
|
sampling_params,
|
||||||
|
lora_req=seq_group.lora_request,
|
||||||
|
)
|
||||||
|
if seq.is_finished():
|
||||||
|
for scheduler in self.scheduler:
|
||||||
|
scheduler.free_seq(seq)
|
||||||
|
return
|
||||||
|
|
||||||
# Process samples
|
# Process samples
|
||||||
samples = outputs.samples
|
samples = outputs.samples
|
||||||
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||||
@ -127,20 +150,20 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
child_seqs.append((parent, parent))
|
child_seqs.append((parent, parent))
|
||||||
|
|
||||||
for seq, _ in child_seqs:
|
for seq, _ in child_seqs:
|
||||||
if seq_group.sampling_params.detokenize and self.detokenizer:
|
if sampling_params.detokenize and self.detokenizer:
|
||||||
new_char_count = self.detokenizer.decode_sequence_inplace(
|
new_char_count = self.detokenizer.decode_sequence_inplace(
|
||||||
seq, seq_group.sampling_params)
|
seq, sampling_params)
|
||||||
else:
|
else:
|
||||||
new_char_count = 0
|
new_char_count = 0
|
||||||
self.stop_checker.maybe_stop_sequence(
|
self.stop_checker.maybe_stop_sequence(
|
||||||
seq,
|
seq,
|
||||||
new_char_count,
|
new_char_count,
|
||||||
seq_group.sampling_params,
|
sampling_params,
|
||||||
lora_req=seq_group.lora_request,
|
lora_req=seq_group.lora_request,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Non-beam search case
|
# Non-beam search case
|
||||||
if not seq_group.sampling_params.use_beam_search:
|
if not sampling_params.use_beam_search:
|
||||||
# For newly created child sequences, add them to the sequence group
|
# For newly created child sequences, add them to the sequence group
|
||||||
# and fork them in block manager if they are not finished.
|
# and fork them in block manager if they are not finished.
|
||||||
for seq, parent in child_seqs:
|
for seq, parent in child_seqs:
|
||||||
@ -164,8 +187,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
# Select the child sequences to keep in the sequence group.
|
# Select the child sequences to keep in the sequence group.
|
||||||
selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
|
selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
|
||||||
unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
|
unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
|
||||||
beam_width = seq_group.sampling_params.best_of
|
beam_width = sampling_params.best_of
|
||||||
length_penalty = seq_group.sampling_params.length_penalty
|
length_penalty = sampling_params.length_penalty
|
||||||
|
|
||||||
# Select the newly finished sequences with the highest scores
|
# Select the newly finished sequences with the highest scores
|
||||||
# to replace existing finished sequences.
|
# to replace existing finished sequences.
|
||||||
@ -219,8 +242,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
best_running_seq = running_child_seqs[0][0]
|
best_running_seq = running_child_seqs[0][0]
|
||||||
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
||||||
stop_beam_search = self._check_beam_search_early_stopping(
|
stop_beam_search = self._check_beam_search_early_stopping(
|
||||||
seq_group.sampling_params.early_stopping,
|
sampling_params.early_stopping, sampling_params,
|
||||||
seq_group.sampling_params, best_running_seq, current_worst_seq)
|
best_running_seq, current_worst_seq)
|
||||||
|
|
||||||
if stop_beam_search:
|
if stop_beam_search:
|
||||||
# Stop the beam search and remove all the running sequences from
|
# Stop the beam search and remove all the running sequences from
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user