Remove eos tokens from output by default (#2611)
This commit is contained in:
parent
51cd22ce56
commit
5a6c81b051
@ -908,13 +908,13 @@ class LLMEngine:
|
|||||||
"""Stop the finished sequences."""
|
"""Stop the finished sequences."""
|
||||||
for stop_str in sampling_params.stop:
|
for stop_str in sampling_params.stop:
|
||||||
if seq.output_text.endswith(stop_str):
|
if seq.output_text.endswith(stop_str):
|
||||||
if not sampling_params.include_stop_str_in_output:
|
self._finalize_sequence(seq, sampling_params, stop_str)
|
||||||
# Truncate the output text so that the stop string is
|
|
||||||
# not included in the output.
|
|
||||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
|
||||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
return
|
return
|
||||||
if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
||||||
|
stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
|
||||||
|
seq.get_last_token_id())
|
||||||
|
self._finalize_sequence(seq, sampling_params, stop_str)
|
||||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -934,6 +934,14 @@ class LLMEngine:
|
|||||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _finalize_sequence(self, seq: Sequence,
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
stop_string: str) -> None:
|
||||||
|
if not sampling_params.include_stop_str_in_output and stop_string:
|
||||||
|
# Truncate the output text so that the stop string is
|
||||||
|
# not included in the output.
|
||||||
|
seq.output_text = seq.output_text[:-len(stop_string)]
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||||
return self._run_workers(
|
return self._run_workers(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user