parent
543aa48573
commit
185ad31f37
@ -21,6 +21,7 @@ from functools import lru_cache
|
|||||||
from typing import Callable, DefaultDict, Dict, List, Union
|
from typing import Callable, DefaultDict, Dict, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from outlines.caching import cache
|
||||||
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
|
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
|
||||||
from outlines.fsm.json_schema import build_regex_from_schema
|
from outlines.fsm.json_schema import build_regex_from_schema
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -67,7 +68,7 @@ class BaseLogitsProcessor:
|
|||||||
class RegexLogitsProcessor(BaseLogitsProcessor):
|
class RegexLogitsProcessor(BaseLogitsProcessor):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@lru_cache(maxsize=32)
|
@cache()
|
||||||
def _get_guide(cls, regex_string: str,
|
def _get_guide(cls, regex_string: str,
|
||||||
tokenizer: PreTrainedTokenizerBase) -> Guide:
|
tokenizer: PreTrainedTokenizerBase) -> Guide:
|
||||||
tokenizer = _adapt_tokenizer(tokenizer)
|
tokenizer = _adapt_tokenizer(tokenizer)
|
||||||
@ -126,7 +127,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
|
|||||||
class CFGLogitsProcessor(BaseLogitsProcessor):
|
class CFGLogitsProcessor(BaseLogitsProcessor):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@lru_cache(maxsize=32)
|
@cache()
|
||||||
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
|
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
|
||||||
tokenizer = _adapt_tokenizer(tokenizer)
|
tokenizer = _adapt_tokenizer(tokenizer)
|
||||||
return CFGGuide(cfg, tokenizer)
|
return CFGGuide(cfg, tokenizer)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user