diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 92568450..02c673c9 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -908,13 +908,13 @@ class LLMEngine:
         """Stop the finished sequences."""
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
-                if not sampling_params.include_stop_str_in_output:
-                    # Truncate the output text so that the stop string is
-                    # not included in the output.
-                    seq.output_text = seq.output_text[:-len(stop_str)]
+                self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
         if seq.get_last_token_id() in sampling_params.stop_token_ids:
+            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
+                seq.get_last_token_id())
+            self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
@@ -934,6 +934,22 @@ class LLMEngine:
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
+    def _finalize_sequence(self, seq: Sequence,
+                           sampling_params: SamplingParams,
+                           stop_string: str) -> None:
+        """Strip a trailing stop string from the sequence's output text.
+
+        Does nothing when ``include_stop_str_in_output`` is set, when
+        ``stop_string`` is empty, or when the output text does not
+        actually end with ``stop_string``.
+        """
+        if sampling_params.include_stop_str_in_output:
+            return
+        # The stop-token caller passes a raw token string, which may differ
+        # from the detokenized text; only truncate on a verified suffix.
+        if stop_string and seq.output_text.endswith(stop_string):
+            seq.output_text = seq.output_text[:-len(stop_string)]
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
         return self._run_workers(