Support arbitrary json_object in OpenAI and Context Free Grammar (#3211)

Simon Mo 2024-03-16 13:35:27 -07:00 committed by GitHub
parent 8e67598aa6
commit 120157fd2a
4 changed files with 177 additions and 50 deletions

tests/entrypoints/test_openai_server.py

@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
         extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
+
+
+async def test_response_format_json_object(server,
+                                           client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": ('what is 1+1? please respond with a JSON object, '
+                        'the format is {"result": 2}')
+        }],
+        response_format={"type": "json_object"})
+
+    content = resp.choices[0].message.content
+    loaded = json.loads(content)
+    assert loaded == {"result": 2}, loaded
+
+
+async def test_guided_grammar(server, client: openai.AsyncOpenAI):
+    simple_sql_grammar = """
+start: select_statement
+
+select_statement: "SELECT" column "from" table "where" condition
+
+column: "col_1" | "col_2"
+table: "table_1" | "table_2"
+condition: column "=" number
+
+number: "1" | "2"
+"""
+
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=("Generate an SQL statement that selects col_1 from "
+                "table_1 where it equals 1"),
+        temperature=1.0,
+        max_tokens=500,
+        extra_body=dict(guided_grammar=simple_sql_grammar))
+
+    content = completion.choices[0].text
+
+    # use Lark to parse the output and make sure it is a valid parse tree
+    from lark import Lark
+    parser = Lark(simple_sql_grammar)
+    parser.parse(content)
+
+    # remove spaces for comparison, because the grammar does not allow them
+    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
+        " ", "")
+
+    assert content.strip() == ground_truth
 
 
 if __name__ == "__main__":
     pytest.main([__file__])
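
Note that `extra_body` is specific to the OpenAI Python client; over plain HTTP the vLLM-specific fields are ordinary top-level keys in the request body. A minimal sketch, reusing `MODEL_NAME` and `simple_sql_grammar` from the test above and assuming a server listening on `localhost:8000`:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed local endpoint
    json={
        "model": MODEL_NAME,
        "prompt": "Generate an SQL statement that selects col_1: ",
        "max_tokens": 100,
        "guided_grammar": simple_sql_grammar,  # sits alongside standard keys
    },
)
print(resp.json()["choices"][0]["text"])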

vllm/entrypoints/openai/protocol.py

@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0
+
+
+class ResponseFormat(BaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"] = "text"
 
 
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Dict[str, str]]
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
     guided_json: Optional[Union[str, dict, BaseModel]] = None
     guided_regex: Optional[str] = None
     guided_choice: Optional[List[str]] = None
+    guided_grammar: Optional[str] = None
+    response_format: Optional[ResponseFormat] = None
 
     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
     guided_json: Optional[Union[str, dict, BaseModel]] = None
     guided_regex: Optional[str] = None
     guided_choice: Optional[List[str]] = None
+    guided_grammar: Optional[str] = None
+    response_format: Optional[ResponseFormat] = None
 
     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0
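
With the `Literal` annotation, Pydantic enforces the allowed values at request-parsing time. A standalone sketch of that behavior (this model mirrors the protocol class above rather than importing it):

from typing import Literal

from pydantic import BaseModel, ValidationError


class DemoResponseFormat(BaseModel):
    # mirrors the protocol model above: only these two values validate
    type: Literal["text", "json_object"] = "text"


print(DemoResponseFormat(type="json_object").type)  # json_object
try:
    DemoResponseFormat(type="yaml")  # any other value is rejected
except ValidationError as err:
    print(err)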

vllm/model_executor/guided_decoding.py

@@ -6,20 +6,51 @@
 from functools import lru_cache
 from json import dumps as json_dumps
 from re import escape as regex_escape
 from typing import Union, Tuple
 
 from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               ChatCompletionRequest)
 from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
-                                                          RegexLogitsProcessor)
+                                                          RegexLogitsProcessor,
+                                                          CFGLogitsProcessor)
 
 
 class GuidedDecodingMode(Enum):
     JSON = "json"
     REGEX = "regex"
     CHOICE = "choice"
+    GRAMMAR = "grammar"
+
+
+# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
+# The main difference is that `start: value` was changed to
+# `start: object | array`, so scalar values are denied as the root of the
+# JSON. Allowing a scalar root seems to cause Llama to generate without
+# stopping.
+JSON_GRAMMAR = r"""
+?start: object | array
+
+?value: object
+      | array
+      | UNESCAPED_STRING
+      | SIGNED_NUMBER      -> number
+      | "true"             -> true
+      | "false"            -> false
+      | "null"             -> null
+
+array  : "[" [value ("," value)*] "]"
+object : "{" [pair ("," pair)*] "}"
+pair   : UNESCAPED_STRING ":" value
+
+%import common.UNESCAPED_STRING
+%import common.SIGNED_NUMBER
+%import common.WS
+
+%ignore WS
+"""
 
 
 global_thread_pool = None  # used for generating logits processor fsm
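
The effect of the `start: object | array` restriction can be checked locally with Lark. `UNESCAPED_STRING` is not part of Lark's bundled `common` grammar (Outlines resolves it from its own grammar files), so this sketch substitutes the built-in `ESCAPED_STRING`; the structure otherwise follows `JSON_GRAMMAR`:

from lark import Lark
from lark.exceptions import UnexpectedInput

demo_grammar = r"""
?start: object | array
?value: object
      | array
      | ESCAPED_STRING
      | SIGNED_NUMBER
      | "true" | "false" | "null"
array  : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair   : ESCAPED_STRING ":" value
%import common.ESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""

parser = Lark(demo_grammar)
parser.parse('{"result": 2}')  # object root: accepted
parser.parse('[1, 2, 3]')      # array root: accepted
try:
    parser.parse('42')         # bare scalar root: rejected
except UnexpectedInput:
    print("scalar roots are denied, as the comment above intends")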
@@ -57,9 +88,6 @@ def _get_guide_and_mode(
 ) -> Tuple[str, GuidedDecodingMode]:
 
     if request.guided_json:
-        if not isinstance(request.guided_json, (str, dict, BaseModel)):
-            raise TypeError("JSON schema must be str, dict, or BaseModel")
-
         json = request.guided_json
         if isinstance(json, dict):
             # turn dict into hashable string
@@ -69,33 +97,33 @@ def _get_guide_and_mode(
             # with the same fields will get hashed the same
             json = str(json.__signature__)
         return json, GuidedDecodingMode.JSON
     elif request.guided_regex:
-        if not isinstance(request.guided_regex, str):
-            raise TypeError("Regex must be string")
         return request.guided_regex, GuidedDecodingMode.REGEX
     elif request.guided_choice:
-        if not isinstance(request.guided_choice, list):
-            raise TypeError("Choices must be a list")
-
         # choice just uses regex
         choices = [
             regex_escape(str(choice)) for choice in request.guided_choice
         ]
         choices_regex = "(" + "|".join(choices) + ")"
         return choices_regex, GuidedDecodingMode.CHOICE
+    elif request.guided_grammar:
+        return request.guided_grammar, GuidedDecodingMode.GRAMMAR
+    elif (request.response_format is not None
+          and request.response_format.type == "json_object"):
+        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
     else:
         return None, None
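
The `guided_choice` branch above reuses the regex machinery: each choice is regex-escaped and the alternatives are joined into a single group. The conversion in isolation:

from re import escape as regex_escape

choices = ["yes", "no", "maybe (later)"]
choices_regex = "(" + "|".join(regex_escape(str(c)) for c in choices) + ")"
print(choices_regex)  # something like (yes|no|maybe\ \(later\))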
 
 
 @lru_cache(maxsize=32)
-def _get_cached_logits_processor(guide: str, tokenizer,
+def _get_cached_logits_processor(guide: str,
+                                 tokenizer: PreTrainedTokenizerBase,
                                  mode: GuidedDecodingMode):
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer)
     elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
         return RegexLogitsProcessor(guide, tokenizer)
+    elif mode == GuidedDecodingMode.GRAMMAR:
+        return CFGLogitsProcessor(guide, tokenizer)
     else:
         raise ValueError(f"Unknown guided decoding mode {mode}")
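
`lru_cache` only accepts hashable arguments, which is why `_get_guide_and_mode` serializes dict schemas to a string, and it means FSM compilation runs at most once per distinct guide. A toy sketch of the pattern (the body is a stand-in, not vLLM's compilation):

from functools import lru_cache
from json import dumps as json_dumps


@lru_cache(maxsize=32)
def compile_guide(guide: str, mode: str) -> str:
    print(f"compiling {mode} FSM")  # the expensive step; runs once per guide
    return f"<fsm:{mode}>"


schema = {"type": "object"}  # dicts are unhashable...
key = json_dumps(schema)     # ...so they are serialized first
compile_guide(key, "json")   # compiles
compile_guide(key, "json")   # second call is served from the cache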

vllm/model_executor/guided_logits_processors.py

@@ -16,30 +16,60 @@
 import json
 import math
 from collections import defaultdict
-from typing import Union, DefaultDict, Dict, List, Optional
+from typing import Union, DefaultDict, Dict, List, Optional, Callable
 
 import torch
 from pydantic import BaseModel
-from outlines.fsm.fsm import RegexFSM
+from transformers import PreTrainedTokenizerBase
+from outlines.fsm.fsm import RegexFSM, CFGFSM
 from outlines.fsm.json_schema import build_regex_from_schema
-class RegexLogitsProcessor:
+class BaseLogitsProcessor:
 
-    def __init__(self, regex_string: str, tokenizer):
-        """Compile the FSM that drives the regex-structured generation.
-
-        Parameters
-        ----------
-        regex_string
-            A string that represents a regular expression
-        tokenizer
-            The model's tokenizer
-        """
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+    def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
+        """Adapt vLLM's tokenizer so it can be used to compile the FSM.
+
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`: the Outlines decoder returns a list, whereas vLLM's
+        `decode` returns a str, so the decoder is adapted to match Outlines'
+        internal API. In addition, the spaces missing from Llama tokenizer
+        output must be handled so that FSMs can be compiled for this model.
+        """
+        if getattr(tokenizer, "_outlines_adapted", False):
+            return tokenizer
+
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        def change_decoder(
+            decoder: Callable[[List[int]], str]
+        ) -> Callable[[List[int]], List[str]]:
+            """Sync vLLM's decoder with Outlines by returning a list."""
+
+            def new_decoder(inp_tokens: List[int]) -> List[str]:
+                return [decoder(inp_tokens)]
+
+            return new_decoder
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+        tokenizer.decode = change_decoder(tokenizer.decode)
+        setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
+        return tokenizer
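
The decoder adaptation is easy to see in isolation: Outlines expects `decode` to return a list of strings, while `transformers` tokenizers return a single string. A self-contained sketch with a toy decoder standing in for `tokenizer.decode`:

from typing import Callable, List


def change_decoder(
        decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
    # wrap a str-returning decoder so it returns a one-element list
    def new_decoder(inp_tokens: List[int]) -> List[str]:
        return [decoder(inp_tokens)]

    return new_decoder


def toy_decode(ids: List[int]) -> str:  # stand-in for tokenizer.decode
    return " ".join(f"<tok{i}>" for i in ids)


print(toy_decode([1, 2]))                  # '<tok1> <tok2>'
print(change_decoder(toy_decode)([1, 2]))  # ['<tok1> <tok2>']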
 
     def init_state(self):
         """Initialize the FSM states."""

@@ -69,38 +99,30 @@ class RegexLogitsProcessor:
         return scores
 
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-        return tokenizer
+
+class RegexLogitsProcessor(BaseLogitsProcessor):
+
+    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
+        """Compile the FSM that drives the regex-structured generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            The model's tokenizer
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = RegexFSM(regex_string, tokenizer)
+        self.fsm = fsm
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
 
     def __init__(self,
                  schema: Union[str, Dict, BaseModel],
-                 tokenizer,
+                 tokenizer: PreTrainedTokenizerBase,
                  whitespace_pattern: Optional[str] = None):
         """Compile the FSM that drives the JSON-guided generation.
@@ -130,3 +152,21 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
                 f"the JSON Schema specification")
         regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
         super().__init__(regex_string, tokenizer)
+
+
+class CFGLogitsProcessor(BaseLogitsProcessor):
+
+    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
+        """Compile the FSM that drives context-free-grammar generation.
+
+        Parameters
+        ----------
+        cfg
+            A string that represents a context-free grammar
+        tokenizer
+            The model's tokenizer
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = CFGFSM(cfg, tokenizer)
+        self.fsm = fsm
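
A rough end-to-end sketch of how the new processor plugs into generation via vLLM's offline API; the model name is a placeholder, and `init_state` is called explicitly here because the server layer normally does that per request:

from vllm import LLM, SamplingParams
from vllm.model_executor.guided_logits_processors import CFGLogitsProcessor

arithmetic_grammar = """
?start: NUMBER "+" NUMBER
%import common.NUMBER
"""

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # placeholder model choice
processor = CFGLogitsProcessor(arithmetic_grammar, llm.get_tokenizer())
processor.init_state()  # reset the per-sequence FSM state

params = SamplingParams(max_tokens=16, logits_processors=[processor])
outputs = llm.generate("Write two plus two as an expression: ", params)
print(outputs[0].outputs[0].text)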