diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 0d1c3280..3f586fe1 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -322,9 +322,15 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
                                              temperature=0.0,
                                              stream=True)
     chunks = []
+    finish_reason_count = 0
     async for chunk in stream:
         chunks.append(chunk.choices[0].text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # finish reason should only return in last block
+    assert finish_reason_count == 1
     assert chunk.choices[0].finish_reason == "length"
+    assert chunk.choices[0].text
     assert chunk.usage == single_usage
     assert "".join(chunks) == single_output
 
@@ -363,13 +369,19 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
         stream=True,
     )
     chunks = []
+    finish_reason_count = 0
     async for chunk in stream:
         delta = chunk.choices[0].delta
         if delta.role:
             assert delta.role == "assistant"
         if delta.content:
             chunks.append(delta.content)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # finish reason should only return in last block
+    assert finish_reason_count == 1
     assert chunk.choices[0].finish_reason == stop_reason
+    assert delta.content
     assert "".join(chunks) == output
 
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 33c79734..9d5319c8 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -266,6 +266,16 @@ class OpenAIServingCompletion(OpenAIServing):
                     previous_texts[i] = output.text
                     previous_num_tokens[i] = len(output.token_ids)
                     finish_reason = output.finish_reason
+                    if output.finish_reason is not None:  # return final usage
+                        prompt_tokens = len(res.prompt_token_ids)
+                        completion_tokens = len(output.token_ids)
+                        final_usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        )
+                    else:
+                        final_usage = None
                     response_json = CompletionStreamResponse(
                         id=request_id,
                         created=created_time,
@@ -277,34 +287,10 @@ class OpenAIServingCompletion(OpenAIServing):
                                 logprobs=logprobs,
                                 finish_reason=finish_reason,
                             )
-                        ]).model_dump_json()
+                        ],
+                        usage=final_usage,
+                    ).model_dump_json(exclude_unset=True)
                     yield f"data: {response_json}\n\n"
-
-                    if output.finish_reason is not None:  # return final usage
-                        logprobs = LogProbs(
-                        ) if request.logprobs is not None else None
-                        prompt_tokens = len(res.prompt_token_ids)
-                        completion_tokens = len(output.token_ids)
-                        final_usage = UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        )
-                        response_json = CompletionStreamResponse(
-                            id=request_id,
-                            created=created_time,
-                            model=model_name,
-                            choices=[
-                                CompletionResponseStreamChoice(
-                                    index=i,
-                                    text="",
-                                    logprobs=logprobs,
-                                    finish_reason=output.finish_reason,
-                                )
-                            ],
-                            usage=final_usage,
-                        ).model_dump_json()
-                        yield f"data: {response_json}\n\n"
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             data = self.create_streaming_error_response(str(e))