diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py
index afcb0f44..ce5bf3d5 100644
--- a/tests/entrypoints/openai/test_chat.py
+++ b/tests/entrypoints/openai/test_chat.py
@@ -837,6 +837,39 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
         assert loaded == {"result": 2}, loaded
 
 
+@pytest.mark.asyncio
+async def test_response_format_json_schema(client: openai.AsyncOpenAI):
+    for _ in range(2):
+        resp = await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[{
+                "role":
+                "user",
+                "content": ('what is 1+1? please respond with a JSON object, '
+                            'the format is {"result": 2}')
+            }],
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "foo_test",
+                    "schema": {
+                        "type": "object",
+                        "properties": {
+                            "result": {
+                                "type": "integer"
+                            },
+                        },
+                    },
+                }
+            })
+
+        content = resp.choices[0].message.content
+        assert content is not None
+
+        loaded = json.loads(content)
+        assert loaded == {"result": 2}, loaded
+
+
 @pytest.mark.asyncio
 async def test_extra_fields(client: openai.AsyncOpenAI):
     with pytest.raises(BadRequestError) as exc_info:
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c46f5cf8..0954b815 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -85,9 +85,19 @@ class UsageInfo(OpenAIBaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class JsonSchemaResponseFormat(OpenAIBaseModel):
+    name: str
+    description: Optional[str] = None
+    # The OpenAI field is named "schema", but that name clashes with
+    # pydantic's BaseModel.schema, so expose it as json_schema with an alias.
+    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
+    strict: Optional[bool] = None
+
+
 class ResponseFormat(OpenAIBaseModel):
-    # type must be "json_object" or "text"
-    type: Literal["text", "json_object"]
+    # type must be "json_schema", "json_object" or "text"
+    type: Literal["text", "json_object", "json_schema"]
+    json_schema: Optional[JsonSchemaResponseFormat] = None
 
 
 class StreamOptions(OpenAIBaseModel):
diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
index b2188c9c..8de811a6 100644
--- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
+++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -49,6 +49,13 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
           and request.response_format.type == "json_object"):
         character_level_parser = JsonSchemaParser(
             None)  # None means any json object
+    elif (request.response_format is not None
+          and request.response_format.type == "json_schema"
+          and request.response_format.json_schema is not None
+          and request.response_format.json_schema.json_schema is not None):
+        schema = _normalize_json_schema_object(
+            request.response_format.json_schema.json_schema)
+        character_level_parser = JsonSchemaParser(schema)
     else:
         return None
 
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index bc62224d..bfc658ef 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -127,6 +127,13 @@ def _get_guide_and_mode(
           and request.response_format is not None
          and request.response_format.type == "json_object"):
         return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
+    elif (not isinstance(request, GuidedDecodingRequest)
+          and request.response_format is not None
+          and request.response_format.type == "json_schema"
+          and request.response_format.json_schema is not None
+          and request.response_format.json_schema.json_schema is not None):
+        json = json_dumps(request.response_format.json_schema.json_schema)
+        return json, GuidedDecodingMode.JSON
     else:
         return None, None
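
Beyond the test added above, the new path can be exercised end to end with the synchronous OpenAI client. This is a minimal sketch, not part of the patch: it assumes a vLLM server is already running and serving some instruct model, and the base_url, api_key, and model name below are placeholders.

# Sketch only: the server address, api_key, and model name are assumptions.
import json

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # whatever model is served
    messages=[{"role": "user", "content": "what is 1+1?"}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "addition_result",
            # Sent as "schema" on the wire; parsed into
            # JsonSchemaResponseFormat.json_schema via the pydantic alias.
            "schema": {
                "type": "object",
                "properties": {"result": {"type": "integer"}},
            },
        },
    },
)

print(json.loads(resp.choices[0].message.content))  # e.g. {"result": 2}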