Allow user to define whitespace pattern for outlines (#4305)

This commit is contained in:
Robert Caulk 2024-05-01 05:48:39 +02:00 committed by GitHub
parent a822eb3413
commit c3845d82dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 8 deletions

View File

@ -57,7 +57,9 @@ def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer) regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) json_LP = JSONLogitsProcessor(TEST_SCHEMA,
tokenizer,
whitespace_pattern=None)
regex_LP.init_state() regex_LP.init_state()
token_ids = tokenizer.encode( token_ids = tokenizer.encode(

View File

@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend " "If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either " "of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'")) "'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend " "If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of " "of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'")) "'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-completion-extra-params # doc: end-completion-extra-params

View File

@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
result = await loop.run_in_executor(global_thread_pool, result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide, _get_cached_logits_processor, guide,
tokenizer, mode) tokenizer, mode,
request.guided_whitespace_pattern)
logits_processor = copy(result) logits_processor = copy(result)
# reset logits processor's internal state # reset logits processor's internal state
@ -117,9 +118,10 @@ def _get_guide_and_mode(
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode): mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]):
if mode == GuidedDecodingMode.JSON: if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer) return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer) return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR: elif mode == GuidedDecodingMode.GRAMMAR:

View File

@ -18,7 +18,7 @@ import json
import math import math
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union from typing import Callable, DefaultDict, Dict, List, Union
import torch import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
class JSONLogitsProcessor(RegexLogitsProcessor): class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, def __init__(self, schema: Union[str, Dict, BaseModel],
schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None): whitespace_pattern: Union[str, None]):
"""Compile the FSM that drives the JSON-guided generation. """Compile the FSM that drives the JSON-guided generation.
Parameters Parameters