From 5b59fe0f08c16e56813f2dad442d44cab222668b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 19 Oct 2024 17:05:02 -0700 Subject: [PATCH] [Bugfix] Pass json-schema to GuidedDecodingParams and make test stronger (#9530) --- tests/entrypoints/openai/test_chat.py | 22 ++++++++++++++++++---- vllm/entrypoints/openai/protocol.py | 16 +++++++++++----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 3af0032f..a2974760 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -851,14 +851,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_response_format_json_schema(client: openai.AsyncOpenAI): + prompt = 'what is 1+1? The format is "result": 2' + # Check that this prompt cannot lead to a valid JSON without json_schema for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, messages=[{ - "role": - "user", - "content": ('what is 1+1? please respond with a JSON object, ' - 'the format is {"result": 2}') + "role": "user", + "content": prompt + }], + ) + content = resp.choices[0].message.content + assert content is not None + with pytest.raises((json.JSONDecodeError, AssertionError)): + loaded = json.loads(content) + assert loaded == {"result": 2}, loaded + + for _ in range(2): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": prompt }], response_format={ "type": "json_schema", diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6f1135f8..06114339 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -314,9 +314,15 @@ class ChatCompletionRequest(OpenAIBaseModel): prompt_logprobs = self.top_logprobs guided_json_object = None - if (self.response_format is not None - and self.response_format.type == "json_object"): - guided_json_object = True + if self.response_format is not None: + if self.response_format.type == "json_object": + guided_json_object = True + elif self.response_format.type == "json_schema": + json_schema = self.response_format.json_schema + assert json_schema is not None + self.guided_json = json_schema.json_schema + if self.guided_decoding_backend is None: + self.guided_decoding_backend = "lm-format-enforcer" guided_decoding = GuidedDecodingParams.from_optional( json=self._get_guided_json_from_tool() or self.guided_json, @@ -537,8 +543,8 @@ class CompletionRequest(OpenAIBaseModel): default=None, description= ("Similar to chat completion, this parameter specifies the format of " - "output. Only {'type': 'json_object'} or {'type': 'text' } is " - "supported."), + "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or " + "{'type': 'text' } is supported."), ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None,