Support arbitrary json_object in OpenAI and Context Free Grammar (#3211)
Parent: 8e67598aa6
Commit: 120157fd2a
@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
         extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


+async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role":
+            "user",
+            "content": ('what is 1+1? please respond with a JSON object, '
+                        'the format is {"result": 2}')
+        }],
+        response_format={"type": "json_object"})
+
+    content = resp.choices[0].message.content
+    loaded = json.loads(content)
+    assert loaded == {"result": 2}, loaded
+
+
+async def test_guided_grammar(server, client: openai.AsyncOpenAI):
+    simple_sql_grammar = """
+start: select_statement
+
+select_statement: "SELECT" column "from" table "where" condition
+
+column: "col_1" | "col_2"
+table: "table_1" | "table_2"
+condition: column "=" number
+
+number: "1" | "2"
+"""
+
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=("Generate a sql state that select col_1 from "
+                "table_1 where it is equals to 1"),
+        temperature=1.0,
+        max_tokens=500,
+        extra_body=dict(guided_grammar=simple_sql_grammar))
+
+    content = completion.choices[0].text
+
+    # use Lark to parse the output, and make sure it's a valid parse tree
+    from lark import Lark
+    parser = Lark(simple_sql_grammar)
+    parser.parse(content)
+
+    # remove spaces for comparison b/c we removed them in the grammar
+    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
+
+    assert content.strip() == ground_truth
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
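A note on the grammar test above: `simple_sql_grammar` declares no whitespace handling, so the only strings it can derive contain no spaces, which is why the ground truth strips them. Below is a standalone sketch, not part of the commit, of the Lark check the test leans on, with the grammar inlined and slightly simplified for illustration:

# Standalone sketch (not from this commit): Lark parses only strings the
# grammar can derive, and a spaced variant fails because the grammar never
# declares or ignores whitespace.
from lark import Lark
from lark.exceptions import LarkError

grammar = """
start: "SELECT" column "from" table "where" column "=" number
column: "col_1" | "col_2"
table: "table_1" | "table_2"
number: "1" | "2"
"""
parser = Lark(grammar)
parser.parse("SELECTcol_1fromtable_1wherecol_1=1")  # valid parse tree

try:
    parser.parse("SELECT col_1 from table_1 where col_1 = 1")
except LarkError:
    print("spaces rejected: the grammar defines no whitespace handling")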
@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0


+class ResponseFormat(BaseModel):
+    # type must be "json_object" or "text"
+    type: str = Literal["text", "json_object"]
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Dict[str, str]]
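One observation on `ResponseFormat` as committed: the annotation is `str` and the `Literal[...]` object serves as the field's default, so pydantic does not actually constrain the value; the server-side check only ever compares against "json_object". A stricter spelling, shown purely as a sketch and not part of this commit, would put the `Literal` in the annotation:

# Sketch only (not from this commit): Literal as the annotation with a plain
# string default makes pydantic reject any other value.
from typing import Literal
from pydantic import BaseModel

class ResponseFormat(BaseModel):
    type: Literal["text", "json_object"] = "text"

ResponseFormat(type="json_object")  # ok
# ResponseFormat(type="yaml")       # would raise a ValidationError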
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
     guided_json: Optional[Union[str, dict, BaseModel]] = None
     guided_regex: Optional[str] = None
     guided_choice: Optional[List[str]] = None
+    guided_grammar: Optional[str] = None
+    response_format: Optional[ResponseFormat] = None

     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
     guided_json: Optional[Union[str, dict, BaseModel]] = None
     guided_regex: Optional[str] = None
     guided_choice: Optional[List[str]] = None
+    guided_grammar: Optional[str] = None
+    response_format: Optional[ResponseFormat] = None

     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0
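Both request models now accept `guided_grammar` and `response_format`. A hedged client-side sketch follows; the server URL, API key, and model name are placeholders, and it assumes an OpenAI-compatible vLLM server plus the openai>=1.0 SDK:

# Client-side sketch (not from this commit); all endpoint/model values are
# placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# response_format is part of the standard chat API surface...
chat = client.chat.completions.create(
    model="MODEL_NAME",
    messages=[{"role": "user", "content": "Give me a JSON object."}],
    response_format={"type": "json_object"})

# ...while guided_grammar is vLLM-specific, so it travels in extra_body.
completion = client.completions.create(
    model="MODEL_NAME",
    prompt="SELECT",
    extra_body=dict(guided_grammar='start: "SELECT" /[a-z_]+/'))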
@@ -6,20 +6,51 @@ from functools import lru_cache
 from json import dumps as json_dumps
 from re import escape as regex_escape
 from typing import Union, Tuple

 from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase

 from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               ChatCompletionRequest)
 from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
-                                                          RegexLogitsProcessor)
+                                                          RegexLogitsProcessor,
+                                                          CFGLogitsProcessor)


 class GuidedDecodingMode(Enum):
     JSON = "json"
     REGEX = "regex"
     CHOICE = "choice"
+    GRAMMAR = "grammar"
+
+
+# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
+# the main difference is that we changed the start: value to
+# start: object | array, so we are denying scalar values as the root of the
+# JSON. Starting with scalars as the root seems to cause llama to generate
+# without stop.
+JSON_GRAMMAR = r"""
+?start: object | array
+
+?value: object
+| array
+| UNESCAPED_STRING
+| SIGNED_NUMBER      -> number
+| "true"             -> true
+| "false"            -> false
+| "null"             -> null
+
+array  : "[" [value ("," value)*] "]"
+object : "{" [pair ("," pair)*] "}"
+pair   : UNESCAPED_STRING ":" value
+
+%import common.UNESCAPED_STRING
+%import common.SIGNED_NUMBER
+%import common.WS
+
+%ignore WS
+"""
+

 global_thread_pool = None  # used for generating logits processor fsm
@@ -57,9 +88,6 @@ def _get_guide_and_mode(
 ) -> Tuple[str, GuidedDecodingMode]:

     if request.guided_json:
-        if not isinstance(request.guided_json, (str, dict, BaseModel)):
-            raise TypeError("JSON schema must be str, dict, or BaseModel")
-
         json = request.guided_json
         if isinstance(json, dict):
             # turn dict into hashable string
@@ -69,33 +97,33 @@ def _get_guide_and_mode(
             # with the same fields will get hashed the same
             json = str(json.__signature__)
         return json, GuidedDecodingMode.JSON

     elif request.guided_regex:
-        if not isinstance(request.guided_regex, str):
-            raise TypeError("Regex must be string")
         return request.guided_regex, GuidedDecodingMode.REGEX

     elif request.guided_choice:
-        if not isinstance(request.guided_choice, list):
-            raise TypeError("Choices must be a list")
-
         # choice just uses regex
         choices = [
             regex_escape(str(choice)) for choice in request.guided_choice
         ]
         choices_regex = "(" + "|".join(choices) + ")"
         return choices_regex, GuidedDecodingMode.CHOICE

+    elif request.guided_grammar:
+        return request.guided_grammar, GuidedDecodingMode.GRAMMAR
+
+    elif (request.response_format is not None
+          and request.response_format.type == "json_object"):
+        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
+
     else:
         return None, None


 @lru_cache(maxsize=32)
-def _get_cached_logits_processor(guide: str, tokenizer,
+def _get_cached_logits_processor(guide: str,
+                                 tokenizer: PreTrainedTokenizerBase,
                                  mode: GuidedDecodingMode):
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer)
     elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
         return RegexLogitsProcessor(guide, tokenizer)
+    elif mode == GuidedDecodingMode.GRAMMAR:
+        return CFGLogitsProcessor(guide, tokenizer)
     else:
         raise ValueError(f"Unknown guided decoding mode {mode}")
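The hunk above also explains the "turn dict into hashable string" step earlier in `_get_guide_and_mode`: `_get_cached_logits_processor` is memoized with `lru_cache`, whose arguments must be hashable, so dict schemas are serialized first. A minimal sketch of that constraint; the function name here is hypothetical:

# Minimal sketch (not from this commit) of why the guide must be a string:
# functools.lru_cache hashes its arguments, and dicts are unhashable.
from functools import lru_cache
from json import dumps as json_dumps

@lru_cache(maxsize=32)
def compile_guide(guide: str) -> str:  # hypothetical stand-in for the FSM build
    return f"fsm<{guide}>"

schema = {"type": "object", "properties": {"result": {"type": "integer"}}}
# compile_guide(schema)                   # TypeError: unhashable type: 'dict'
key = json_dumps(schema, sort_keys=True)  # canonical, hashable cache key
assert compile_guide(key) == compile_guide(key)  # second call hits the cache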
@@ -16,30 +16,60 @@
 import json
 import math
 from collections import defaultdict
-from typing import Union, DefaultDict, Dict, List, Optional
+from typing import Union, DefaultDict, Dict, List, Optional, Callable

 import torch
 from pydantic import BaseModel
-from outlines.fsm.fsm import RegexFSM
+from transformers import PreTrainedTokenizerBase
+from outlines.fsm.fsm import RegexFSM, CFGFSM
 from outlines.fsm.json_schema import build_regex_from_schema


-class RegexLogitsProcessor:
+class BaseLogitsProcessor:

-    def __init__(self, regex_string: str, tokenizer):
-        """Compile the FSM that drives the regex-structured generation.
-
-        Parameters
-        ----------
-        regex_string
-            A string that represents a regular expression
-        tokenizer
-            The model's tokenizer
-
-        """
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+    def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
+        """Adapt vLLM's tokenizer so it can be used to compile the FSM.
+
+        The API of Outlines tokenizers differs slightly from that of
+        `transformers`: the Outlines decoder returns a list, whereas vLLM's
+        decode returns a str, so the decoder must be wrapped to match the
+        Outlines internal API. In addition, we need to handle the missing
+        spaces in Llama's tokenizer to be able to compile FSMs for that
+        model.
+        """
+        if getattr(tokenizer, "_outlines_adapted", False):
+            return tokenizer
+
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        def change_decoder(
+            decoder: Callable[[List[int]], str]
+        ) -> Callable[[List[int]], List[str]]:
+            """Sync vLLM's decoder with Outlines by returning a list."""
+
+            def new_decoder(inp_tokens: List[int]) -> List[str]:
+                return [decoder(inp_tokens)]
+
+            return new_decoder
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+        tokenizer.decode = change_decoder(tokenizer.decode)
+        setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
+
+        return tokenizer

     def init_state(self):
         """Initialize the FSM states."""
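The `_outlines_adapted` flag makes the adaptation idempotent: the tokenizer instance is shared, and an unguarded second call would wrap `decode` again, nesting its result one list deeper each time. A toy sketch of that failure mode, not from the commit:

# Toy sketch (not from this commit) of the double-wrapping hazard that the
# `_outlines_adapted` guard prevents.
from typing import Callable

def change_decoder(decoder: Callable) -> Callable:
    return lambda inp_tokens: [decoder(inp_tokens)]

decode = lambda ids: "a b"       # stand-in for tokenizer.decode
once = change_decoder(decode)
twice = change_decoder(once)     # what a second, unguarded call would do
assert once([1]) == ["a b"]      # correct shape for Outlines
assert twice([1]) == [["a b"]]   # broken: nested list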
@@ -69,38 +99,30 @@ class RegexLogitsProcessor:

         return scores

-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
+
+class RegexLogitsProcessor(BaseLogitsProcessor):
+
+    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
+        """Compile the FSM that drives the regex-structured generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            The model's tokenizer
+
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = RegexFSM(regex_string, tokenizer)
+        self.fsm = fsm


 class JSONLogitsProcessor(RegexLogitsProcessor):

     def __init__(self,
                  schema: Union[str, Dict, BaseModel],
-                 tokenizer,
+                 tokenizer: PreTrainedTokenizerBase,
                  whitespace_pattern: Optional[str] = None):
         """Compile the FSM that drives the JSON-guided generation.
@@ -130,3 +152,21 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
                                 f"the JSON Schema specification")
         regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
         super().__init__(regex_string, tokenizer)
+
+
+class CFGLogitsProcessor(BaseLogitsProcessor):
+
+    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
+        """Compile the FSM that drives the context-free grammar generation.
+
+        Parameters
+        ----------
+        cfg
+            A string that represents a context-free grammar
+        tokenizer
+            The model's tokenizer
+
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = CFGFSM(cfg, tokenizer)
+        self.fsm = fsm
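All three processors share the same per-step contract: given the tokens generated so far and the raw logits, they push every token the FSM disallows from the current state to negative infinity before sampling. An illustrative sketch of that masking pattern with a toy allow-list; names and shapes here are illustrative, not the vLLM API:

# Illustrative sketch (not from this commit) of the per-step masking pattern
# a guided logits processor implements.
import math
import torch

def mask_logits(allowed_token_ids: list, scores: torch.Tensor) -> torch.Tensor:
    mask = torch.full_like(scores, -math.inf)
    mask[allowed_token_ids] = 0              # keep only FSM-allowed tokens
    return scores + mask

logits = torch.zeros(8)                      # toy vocab of 8 tokens
out = mask_logits([2, 5], logits)
assert out[2] == 0 and math.isinf(out[0])    # everything else is -inf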