[Bugfix] Add missing attributes in mistral tokenizer (#8364)
parent aea02f30de
commit 7015417fd4
@@ -519,11 +519,14 @@ def apply_hf_chat_template(
 def apply_mistral_chat_template(
     tokenizer: MistralTokenizer,
     messages: List[ChatCompletionMessageParam],
-    chat_template: Optional[str],
+    chat_template: Optional[str] = None,
     **kwargs: Any,
 ) -> List[int]:
+    if chat_template is not None:
+        logger.warning(
+            "'chat_template' cannot be overridden for mistral tokenizer.")
+
     return tokenizer.apply_chat_template(
         messages=messages,
-        chat_template=chat_template,
         **kwargs,
     )
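A minimal usage sketch of the changed helper (not part of the diff): the signature comes from the hunk above, but the import paths, the from_pretrained argument, and the repo id are assumptions for illustration only.

    # Hypothetical usage sketch; import paths and repo id are assumptions.
    from vllm.entrypoints.chat_utils import apply_mistral_chat_template
    from vllm.transformers_utils.tokenizers import MistralTokenizer

    tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
    messages = [{"role": "user", "content": "Hello!"}]

    # chat_template now defaults to None; passing one only logs the warning,
    # since mistral tokenizers carry their own template.
    prompt_ids = apply_mistral_chat_template(tokenizer, messages=messages)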
@@ -45,26 +45,25 @@ class MistralTokenizer:
     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
-        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
 
-        self.vocab_size = len(self.tokenizer.vocab())
-
-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
-
-        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
+        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        if isinstance(tokenizer_, Tekkenizer):
             # Make sure special tokens will not raise
-            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
+            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
 
-        self._is_tekken = is_tekken
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        elif isinstance(tokenizer_, SentencePieceTokenizer):
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        else:
+            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
-        # the following attributes are set to fit VLLM's design
-        self.is_fast = True
-        self.chat_template = True
-        self.all_special_ids: List[Any] = []
-        self.all_special_tokens: List[Any] = []
-        self.all_special_tokens_extended: List[Any] = []
+        self.tokenizer = tokenizer_
 
     @classmethod
     def from_pretrained(cls,
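A toy sketch (not from the commit) of the token-to-id mapping the rewritten __init__ builds; it is the same dict-comprehension pattern applied to a stand-in vocab list.

    toy_vocab = ["<s>", "</s>", "hello", "world"]  # stand-in for tokenizer_.vocab()
    vocab = {token: idx for idx, token in enumerate(toy_vocab)}
    assert vocab["hello"] == 2
    assert len(vocab) == len(toy_vocab)  # what the new vocab_size property reports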
@@ -102,6 +101,38 @@ class MistralTokenizer:
                                          revision=revision)
         return tokenizer_file
 
+    # the following attributes are set to fit VLLM's design
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        return []
+
+    @property
+    def bos_token_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
     def __call__(
         self,
         prompt: str,
@@ -117,9 +148,12 @@ class MistralTokenizer:
 
         return Encoding(input_ids=input_ids)
 
-    def get_added_vocab(self) -> List[str]:
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab
+
+    def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
-        return []
+        return {}
 
     def encode(self, prompt: str) -> List[int]:
         # `encode` should only be used for prompt completion
@@ -141,7 +175,7 @@ class MistralTokenizer:
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self._is_tekken:
+        if isinstance(self.tokenizer, Tekkenizer):
             return "".join(tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
@@ -151,14 +185,11 @@ class MistralTokenizer:
             ids = [ids]
         return self.tokenizer.decode(ids)
 
-    @property
-    def eos_token_id(self):
-        return self.tokenizer.eos_id
-
     def convert_ids_to_tokens(
-            self,
-            ids: List[int],
-            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
@@ -170,6 +201,3 @@ class MistralTokenizer:
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
         return tokens
-
-    def __len__(self):
-        return self.vocab_size
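For reference, a short sketch of the surface this commit restores, assuming `tokenizer` is a MistralTokenizer instance; the names and return values are taken directly from the hunks above.

    tokenizer.is_fast              # True (property instead of an instance attribute)
    tokenizer.vocab_size           # len(tokenizer.get_vocab())
    len(tokenizer)                 # same as tokenizer.vocab_size
    tokenizer.get_vocab()          # Dict[str, int] mapping token -> id
    tokenizer.get_added_vocab()    # {} - mistral tokenizers have no added vocabulary
    tokenizer.bos_token_id         # underlying tokenizer's bos_id
    tokenizer.eos_token_id         # underlying tokenizer's eos_id
    tokenizer.all_special_tokens   # [] (same for all_special_ids / all_special_tokens_extended)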