[Bugfix] use diskcache in outlines _get_guide #5436 (#6203)

This commit is contained in:
Eric 2024-07-09 02:23:24 +08:00 committed by GitHub
parent 543aa48573
commit 185ad31f37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -21,6 +21,7 @@ from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union from typing import Callable, DefaultDict, Dict, List, Union
import torch import torch
from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from outlines.fsm.json_schema import build_regex_from_schema from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel from pydantic import BaseModel
@@ -67,7 +68,7 @@ class BaseLogitsProcessor:
class RegexLogitsProcessor(BaseLogitsProcessor): class RegexLogitsProcessor(BaseLogitsProcessor):
@classmethod @classmethod
@lru_cache(maxsize=32) @cache()
def _get_guide(cls, regex_string: str, def _get_guide(cls, regex_string: str,
tokenizer: PreTrainedTokenizerBase) -> Guide: tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer) tokenizer = _adapt_tokenizer(tokenizer)
@@ -126,7 +127,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
class CFGLogitsProcessor(BaseLogitsProcessor): class CFGLogitsProcessor(BaseLogitsProcessor):
@classmethod @classmethod
@lru_cache(maxsize=32) @cache()
def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer) tokenizer = _adapt_tokenizer(tokenizer)
return CFGGuide(cfg, tokenizer) return CFGGuide(cfg, tokenizer)