Allow user to define whitespace pattern for outlines (#4305)
This commit is contained in:
parent
a822eb3413
commit
c3845d82dc
@ -57,7 +57,9 @@ def test_guided_logits_processors():
|
|||||||
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
||||||
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
||||||
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
|
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
|
||||||
json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
|
json_LP = JSONLogitsProcessor(TEST_SCHEMA,
|
||||||
|
tokenizer,
|
||||||
|
whitespace_pattern=None)
|
||||||
|
|
||||||
regex_LP.init_state()
|
regex_LP.init_state()
|
||||||
token_ids = tokenizer.encode(
|
token_ids = tokenizer.encode(
|
||||||
|
|||||||
@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
"If specified, will override the default guided decoding backend "
|
"If specified, will override the default guided decoding backend "
|
||||||
"of the server for this specific request. If set, must be either "
|
"of the server for this specific request. If set, must be either "
|
||||||
"'outlines' / 'lm-format-enforcer'"))
|
"'outlines' / 'lm-format-enforcer'"))
|
||||||
|
guided_whitespace_pattern: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, will override the default whitespace pattern "
|
||||||
|
"for guided json decoding."))
|
||||||
|
|
||||||
# doc: end-chat-completion-extra-params
|
# doc: end-chat-completion-extra-params
|
||||||
|
|
||||||
@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
"If specified, will override the default guided decoding backend "
|
"If specified, will override the default guided decoding backend "
|
||||||
"of the server for this specific request. If set, must be one of "
|
"of the server for this specific request. If set, must be one of "
|
||||||
"'outlines' / 'lm-format-enforcer'"))
|
"'outlines' / 'lm-format-enforcer'"))
|
||||||
|
guided_whitespace_pattern: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, will override the default whitespace pattern "
|
||||||
|
"for guided json decoding."))
|
||||||
|
|
||||||
# doc: end-completion-extra-params
|
# doc: end-completion-extra-params
|
||||||
|
|
||||||
|
|||||||
@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
|
|||||||
|
|
||||||
result = await loop.run_in_executor(global_thread_pool,
|
result = await loop.run_in_executor(global_thread_pool,
|
||||||
_get_cached_logits_processor, guide,
|
_get_cached_logits_processor, guide,
|
||||||
tokenizer, mode)
|
tokenizer, mode,
|
||||||
|
request.guided_whitespace_pattern)
|
||||||
|
|
||||||
logits_processor = copy(result)
|
logits_processor = copy(result)
|
||||||
# reset logits processor's internal state
|
# reset logits processor's internal state
|
||||||
@ -117,9 +118,10 @@ def _get_guide_and_mode(
|
|||||||
@lru_cache(maxsize=32)
|
@lru_cache(maxsize=32)
|
||||||
def _get_cached_logits_processor(guide: str,
|
def _get_cached_logits_processor(guide: str,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
mode: GuidedDecodingMode):
|
mode: GuidedDecodingMode,
|
||||||
|
whitespace_pattern: Union[str, None]):
|
||||||
if mode == GuidedDecodingMode.JSON:
|
if mode == GuidedDecodingMode.JSON:
|
||||||
return JSONLogitsProcessor(guide, tokenizer)
|
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
|
||||||
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
||||||
return RegexLogitsProcessor(guide, tokenizer)
|
return RegexLogitsProcessor(guide, tokenizer)
|
||||||
elif mode == GuidedDecodingMode.GRAMMAR:
|
elif mode == GuidedDecodingMode.GRAMMAR:
|
||||||
|
|||||||
@ -18,7 +18,7 @@ import json
|
|||||||
import math
|
import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Callable, DefaultDict, Dict, List, Optional, Union
|
from typing import Callable, DefaultDict, Dict, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
|
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
|
||||||
@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
|
|||||||
|
|
||||||
class JSONLogitsProcessor(RegexLogitsProcessor):
|
class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, schema: Union[str, Dict, BaseModel],
|
||||||
schema: Union[str, Dict, BaseModel],
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
whitespace_pattern: Optional[str] = None):
|
whitespace_pattern: Union[str, None]):
|
||||||
"""Compile the FSM that drives the JSON-guided generation.
|
"""Compile the FSM that drives the JSON-guided generation.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user