Allow user to define whitespace pattern for outlines (#4305)

This commit is contained in:
Robert Caulk 2024-05-01 05:48:39 +02:00 committed by GitHub
parent a822eb3413
commit c3845d82dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 8 deletions

View File

@@ -57,7 +57,9 @@ def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
-json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
+json_LP = JSONLogitsProcessor(TEST_SCHEMA,
+tokenizer,
+whitespace_pattern=None)
regex_LP.init_state()
token_ids = tokenizer.encode(

View File

@@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
+guided_whitespace_pattern: Optional[str] = Field(
+default=None,
+description=(
+"If specified, will override the default whitespace pattern "
+"for guided json decoding."))
# doc: end-chat-completion-extra-params
@@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
+guided_whitespace_pattern: Optional[str] = Field(
+default=None,
+description=(
+"If specified, will override the default whitespace pattern "
+"for guided json decoding."))
# doc: end-completion-extra-params

View File

@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide,
-tokenizer, mode)
+tokenizer, mode,
+request.guided_whitespace_pattern)
logits_processor = copy(result)
# reset logits processor's internal state
@@ -117,9 +118,10 @@ def _get_guide_and_mode(
@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase,
-mode: GuidedDecodingMode):
+mode: GuidedDecodingMode,
+whitespace_pattern: Union[str, None]):
if mode == GuidedDecodingMode.JSON:
-return JSONLogitsProcessor(guide, tokenizer)
+return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR:

View File

@@ -18,7 +18,7 @@ import json
import math
from collections import defaultdict
from functools import lru_cache
-from typing import Callable, DefaultDict, Dict, List, Optional, Union
+from typing import Callable, DefaultDict, Dict, List, Union
import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
class JSONLogitsProcessor(RegexLogitsProcessor):
-def __init__(self,
-schema: Union[str, Dict, BaseModel],
+def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
-whitespace_pattern: Optional[str] = None):
+whitespace_pattern: Union[str, None]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters