Allow user to define whitespace pattern for outlines (#4305)
This commit is contained in:
parent a822eb3413
commit c3845d82dc
@@ -57,7 +57,9 @@ def test_guided_logits_processors():
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
     regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
-    json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
+    json_LP = JSONLogitsProcessor(TEST_SCHEMA,
+                                  tokenizer,
+                                  whitespace_pattern=None)
 
     regex_LP.init_state()
     token_ids = tokenizer.encode(
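For context, a minimal usage sketch (not part of the diff) of the new parameter. It assumes the same module context as the test above (`JSONLogitsProcessor` and `TEST_SCHEMA` imported there); the pattern `r"[ ]?"` is only an illustrative choice that limits structural whitespace in the generated JSON to at most a single space between tokens:

```python
from transformers import AutoTokenizer

# Illustrative only: reuses the TEST_SCHEMA fixture and tokenizer from the test.
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
compact_json_LP = JSONLogitsProcessor(TEST_SCHEMA,
                                      tokenizer,
                                      whitespace_pattern=r"[ ]?")
```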
@@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "If specified, will override the default guided decoding backend "
             "of the server for this specific request. If set, must be either "
             "'outlines' / 'lm-format-enforcer'"))
+    guided_whitespace_pattern: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default whitespace pattern "
+            "for guided json decoding."))
 
     # doc: end-chat-completion-extra-params
@@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel):
             "If specified, will override the default guided decoding backend "
             "of the server for this specific request. If set, must be one of "
             "'outlines' / 'lm-format-enforcer'"))
+    guided_whitespace_pattern: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default whitespace pattern "
+            "for guided json decoding."))
 
     # doc: end-completion-extra-params
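A sketch of how a client might exercise the new field against a vLLM OpenAI-compatible server. The base URL, model name, and schema below are placeholder assumptions; `extra_body` is the standard way the `openai` Python client forwards provider-specific fields such as `guided_json` and the new `guided_whitespace_pattern`:

```python
from openai import OpenAI

# Placeholder server/model; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

schema = {"type": "object",
          "properties": {"name": {"type": "string"}},
          "required": ["name"]}

resp = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Give me a JSON object with a name."}],
    extra_body={
        "guided_json": schema,
        # New in this change: override the whitespace pattern used for
        # guided JSON decoding (here: at most one space or newline between tokens).
        "guided_whitespace_pattern": r"[\n ]?",
    },
)
print(resp.choices[0].message.content)
```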
@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
 
     result = await loop.run_in_executor(global_thread_pool,
                                         _get_cached_logits_processor, guide,
-                                        tokenizer, mode)
+                                        tokenizer, mode,
+                                        request.guided_whitespace_pattern)
 
     logits_processor = copy(result)
     # reset logits processor's internal state
@@ -117,9 +118,10 @@ def _get_guide_and_mode(
 @lru_cache(maxsize=32)
 def _get_cached_logits_processor(guide: str,
                                  tokenizer: PreTrainedTokenizerBase,
-                                 mode: GuidedDecodingMode):
+                                 mode: GuidedDecodingMode,
+                                 whitespace_pattern: Union[str, None]):
     if mode == GuidedDecodingMode.JSON:
-        return JSONLogitsProcessor(guide, tokenizer)
+        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.GRAMMAR:
||||
@@ -18,7 +18,7 @@ import json
 import math
 from collections import defaultdict
 from functools import lru_cache
-from typing import Callable, DefaultDict, Dict, List, Optional, Union
+from typing import Callable, DefaultDict, Dict, List, Union
 
 import torch
 from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
 
-    def __init__(self,
-                 schema: Union[str, Dict, BaseModel],
+    def __init__(self, schema: Union[str, Dict, BaseModel],
                  tokenizer: PreTrainedTokenizerBase,
-                 whitespace_pattern: Optional[str] = None):
+                 whitespace_pattern: Union[str, None]):
         """Compile the FSM that drives the JSON-guided generation.
 
         Parameters
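Note that the final signature uses `whitespace_pattern: Union[str, None]` with no default, so every caller must pass the argument explicitly, which is why the test and `_get_cached_logits_processor` above were also updated. A hedged sketch of the two call styles, assuming a `schema` dict and a `tokenizer` are already in scope (the concrete pattern is only an illustrative choice):

```python
# Keep the library's default whitespace handling between JSON tokens.
default_lp = JSONLogitsProcessor(schema, tokenizer, whitespace_pattern=None)

# Allow at most whitespace matching this regex between JSON tokens.
compact_lp = JSONLogitsProcessor(schema, tokenizer, whitespace_pattern=r"[\n ]*")
```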