[Frontend] add json_schema support from OpenAI protocol (#7654)

This commit is contained in:
Tyler Rockwood 2024-08-24 01:07:24 -05:00 committed by GitHub
parent 8da48e4d95
commit d81abefd2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 2 deletions

View File

@ -837,6 +837,39 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
assert loaded == {"result": 2}, loaded
@pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": ('what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}')
}],
response_format={
"type": "json_schema",
"json_schema": {
"name": "foo_test",
"schema": {
"type": "object",
"properties": {
"result": {
"type": "integer"
},
},
},
}
})
content = resp.choices[0].message.content
assert content is not None
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
@pytest.mark.asyncio
async def test_extra_fields(client: openai.AsyncOpenAI):
with pytest.raises(BadRequestError) as exc_info:

View File

@ -85,9 +85,19 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens: Optional[int] = 0
class JsonSchemaResponseFormat(OpenAIBaseModel):
name: str
description: Optional[str] = None
# schema is the field in openai but that causes conflicts with pydantic so
# instead use json_schema with an alias
json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
strict: Optional[bool] = None
class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
# type must be "json_schema", "json_object" or "text"
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
class StreamOptions(OpenAIBaseModel):

View File

@ -49,6 +49,13 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
elif (request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
schema = _normalize_json_schema_object(
request.response_format.json_schema.json_schema)
character_level_parser = JsonSchemaParser(schema)
else:
return None

View File

@ -127,6 +127,13 @@ def _get_guide_and_mode(
and request.response_format is not None
and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
json = json_dumps(request.response_format.json_schema.json_schema)
return json, GuidedDecodingMode.JSON
else:
return None, None