[Bugfix] Add missing attributes in mistral tokenizer (#8364)

Cyrus Leung 2024-09-12 02:36:54 +08:00 committed by GitHub
parent aea02f30de
commit 7015417fd4
2 changed files with 62 additions and 31 deletions


@@ -519,11 +519,14 @@ def apply_hf_chat_template(
 def apply_mistral_chat_template(
     tokenizer: MistralTokenizer,
     messages: List[ChatCompletionMessageParam],
-    chat_template: Optional[str],
+    chat_template: Optional[str] = None,
     **kwargs: Any,
 ) -> List[int]:
+    if chat_template is not None:
+        logger.warning(
+            "'chat_template' cannot be overridden for mistral tokenizer.")
     return tokenizer.apply_chat_template(
         messages=messages,
         chat_template=chat_template,
         **kwargs,
     )
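
For illustration, a minimal call sketch against the new signature. Only the parameters of apply_mistral_chat_template come from the diff above; the import paths, model name, and message content are assumptions.

from vllm.entrypoints.chat_utils import apply_mistral_chat_template
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

# hypothetical model; any repo shipping a Mistral tokenizer would work the same way
tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
messages = [{"role": "user", "content": "Hello!"}]

# chat_template can now be omitted (it defaults to None); passing one only logs
# a warning, since the Mistral tokenizer always uses its built-in template.
prompt_ids = apply_mistral_chat_template(tokenizer, messages=messages)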


@@ -45,26 +45,25 @@ class MistralTokenizer:
     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
-        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
 
-        self.vocab_size = len(self.tokenizer.vocab())
-
-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
-
-        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
+        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        if isinstance(tokenizer_, Tekkenizer):
             # Make sure special tokens will not raise
-            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
+            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
 
-        self._is_tekken = is_tekken
-
-        # the following attributes are set to fit VLLM's design
-        self.is_fast = True
-        self.chat_template = True
-        self.all_special_ids: List[Any] = []
-        self.all_special_tokens: List[Any] = []
-        self.all_special_tokens_extended: List[Any] = []
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        elif isinstance(tokenizer_, SentencePieceTokenizer):
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        else:
+            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
+
+        self.tokenizer = tokenizer_
 
     @classmethod
     def from_pretrained(cls,
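
As an aside, a minimal sketch of the token-to-id mapping the reworked constructor builds; the four-token vocabulary below is made up and merely stands in for tokenizer_.vocab().

# hypothetical stand-in for tokenizer_.vocab()
vocab_list = ["<s>", "</s>", "hello", "world"]

# same comprehension as in __init__ above: token string -> integer id
_vocab = {token: idx for idx, token in enumerate(vocab_list)}

assert _vocab["hello"] == 2
assert len(_vocab) == 4  # this is what the new vocab_size property reports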
@@ -102,6 +101,38 @@ class MistralTokenizer:
                                          revision=revision)
         return tokenizer_file
 
+    # the following attributes are set to fit VLLM's design
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        return []
+
+    @property
+    def bos_token_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
     def __call__(
         self,
         prompt: str,
@@ -117,9 +148,12 @@ class MistralTokenizer:
 
         return Encoding(input_ids=input_ids)
 
-    def get_added_vocab(self) -> List[str]:
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab
+
+    def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
-        return []
+        return {}
 
     def encode(self, prompt: str) -> List[int]:
         # `encode` should only be used for prompt completion
@@ -141,7 +175,7 @@ class MistralTokenizer:
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self._is_tekken:
+        if isinstance(self.tokenizer, Tekkenizer):
             return "".join(tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
@@ -151,14 +185,11 @@ class MistralTokenizer:
             ids = [ids]
         return self.tokenizer.decode(ids)
 
-    @property
-    def eos_token_id(self):
-        return self.tokenizer.eos_id
-
     def convert_ids_to_tokens(
-            self,
-            ids: List[int],
-            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
@@ -170,6 +201,3 @@ class MistralTokenizer:
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
         return tokens
-
-    def __len__(self):
-        return self.vocab_size
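
Taken together, a rough usage sketch of the attributes this commit adds back onto the wrapper. The model name and exact import path are assumptions; every attribute exercised below appears in the diff.

from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

# properties restored by this commit
print(tokenizer.is_fast)                               # True
print(tokenizer.bos_token_id, tokenizer.eos_token_id)  # from the underlying tokenizer
print(tokenizer.vocab_size, len(tokenizer))            # both report len(self._vocab)
print(tokenizer.all_special_tokens)                    # [] for Mistral tokenizers

vocab = tokenizer.get_vocab()        # Dict[str, int], newly added method
added = tokenizer.get_added_vocab()  # now returns {} instead of []

# round-trip helpers, shown for context
ids = tokenizer.encode("Hello world")
tokens = tokenizer.convert_ids_to_tokens(ids)
text = tokenizer.convert_tokens_to_string(tokens)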