From 5a6c81b0511da333b1fabf5ad612eb7874d5e88e Mon Sep 17 00:00:00 2001
From: Rex
Date: Sun, 4 Feb 2024 14:32:42 -0800
Subject: [PATCH] Remove eos tokens from output by default (#2611)

---
 vllm/engine/llm_engine.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 92568450..02c673c9 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -908,13 +908,13 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                if not sampling_params.include_stop_str_in_output:
-                    # Truncate the output text so that the stop string is
-                    # not included in the output.
-                    seq.output_text = seq.output_text[:-len(stop_str)]
+                self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
+            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
+                seq.get_last_token_id())
+            self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
@@ -934,6 +934,14 @@ class LLMEngine:
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
+    def _finalize_sequence(self, seq: Sequence,
+                           sampling_params: SamplingParams,
+                           stop_string: str) -> None:
+        if not sampling_params.include_stop_str_in_output and stop_string:
+            # Truncate the output text so that the stop string is
+            # not included in the output.
+            seq.output_text = seq.output_text[:-len(stop_string)]
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
         return self._run_workers(
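
Note: a minimal usage sketch of how this change is observed from the public
API, assuming a local vLLM install. The model name and the EOS token id
(2 is OPT's) are illustrative assumptions, not taken from the patch.

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # assumed model, for illustration

    params = SamplingParams(
        max_tokens=64,
        stop=["\n\n"],                     # stop-string path (first hunk)
        stop_token_ids=[2],                # stop-token path (second hunk)
        include_stop_str_in_output=False,  # the default this patch acts on
    )
    outputs = llm.generate(["Hello, my name is"], params)
    # With include_stop_str_in_output=False, neither the stop string nor the
    # token text of a matched stop_token_id appears in the generated text.
    print(outputs[0].outputs[0].text)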