[Bugfix] Add missing attributes in mistral tokenizer (#8364)

Cyrus Leung 2024-09-12 02:36:54 +08:00 committed by GitHub
parent aea02f30de
commit 7015417fd4
2 changed files with 62 additions and 31 deletions


@@ -519,11 +519,14 @@ def apply_hf_chat_template(
 def apply_mistral_chat_template(
     tokenizer: MistralTokenizer,
     messages: List[ChatCompletionMessageParam],
-    chat_template: Optional[str],
+    chat_template: Optional[str] = None,
     **kwargs: Any,
 ) -> List[int]:
+    if chat_template is not None:
+        logger.warning(
+            "'chat_template' cannot be overridden for mistral tokenizer.")
+
     return tokenizer.apply_chat_template(
         messages=messages,
-        chat_template=chat_template,
         **kwargs,
     )
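With this change, chat_template defaults to None and any explicit value is ignored with a warning, since the Mistral tokenizer renders conversations through mistral-common's own template logic. A hedged usage sketch follows; the import paths and the model id are assumptions based on vLLM's layout at the time and are not part of this diff:

from vllm.entrypoints.chat_utils import apply_mistral_chat_template
from vllm.transformers_utils.tokenizers import MistralTokenizer

# Assumed model id; any repo that ships a mistral-common tokenizer file works.
tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
messages = [{"role": "user", "content": "Hello!"}]

# No chat_template argument: mistral-common renders the conversation itself.
prompt_ids = apply_mistral_chat_template(tokenizer, messages=messages)

# Passing chat_template=... would only trigger the warning above and be ignored.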


@@ -45,26 +45,25 @@ class MistralTokenizer:
     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
-        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
-        self.vocab_size = len(self.tokenizer.vocab())
 
-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
-
-        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
+        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        if isinstance(tokenizer_, Tekkenizer):
             # Make sure special tokens will not raise
-            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
+            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
 
-        self._is_tekken = is_tekken
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        elif isinstance(tokenizer_, SentencePieceTokenizer):
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        else:
+            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
-        # the following attributes are set to fit VLLM's design
-        self.is_fast = True
-        self.chat_template = True
-        self.all_special_ids: List[Any] = []
-        self.all_special_tokens: List[Any] = []
-        self.all_special_tokens_extended: List[Any] = []
+        self.tokenizer = tokenizer_
 
     @classmethod
     def from_pretrained(cls,
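The rewritten constructor builds the token-to-id map once from the underlying mistral-common tokenizer and keeps it in self._vocab, raising a TypeError for unsupported tokenizer types instead of asserting. A minimal, self-contained sketch of that mapping pattern, assuming vocab() returns tokens ordered by id (which is what the enumerate relies on):

from typing import Dict, List


def build_vocab_index(vocab: List[str]) -> Dict[str, int]:
    # The position of a token in the vocab list is treated as its id.
    return {token: idx for idx, token in enumerate(vocab)}


vocab_index = build_vocab_index(["<s>", "</s>", "hello", "world"])
assert vocab_index["world"] == 3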
@@ -102,6 +101,38 @@ class MistralTokenizer:
                                          revision=revision)
         return tokenizer_file
 
+    # the following attributes are set to fit VLLM's design
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        return []
+
+    @property
+    def bos_token_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
     def __call__(
         self,
         prompt: str,
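These properties give the wrapper the read-only surface vLLM expects from a Hugging Face tokenizer, backed by the cached _vocab map and the underlying tokenizer's special ids. A short usage sketch, assuming tokenizer is the MistralTokenizer instance from the earlier sketch:

assert tokenizer.is_fast
assert tokenizer.all_special_tokens == []       # no HF-style special tokens are exposed
print(tokenizer.bos_token_id, tokenizer.eos_token_id)
assert len(tokenizer) == tokenizer.vocab_size   # both read the cached _vocab map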
@@ -117,9 +148,12 @@ class MistralTokenizer:
         return Encoding(input_ids=input_ids)
 
-    def get_added_vocab(self) -> List[str]:
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab
+
+    def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
-        return []
+        return {}
 
     def encode(self, prompt: str) -> List[int]:
         # `encode` should only be used for prompt completion
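get_vocab() now returns the full token-to-id mapping built in __init__, and get_added_vocab() switches to a dict to match the Hugging Face return type. A brief sketch with the same tokenizer instance as above:

vocab = tokenizer.get_vocab()
assert isinstance(vocab, dict) and len(vocab) == tokenizer.vocab_size
assert tokenizer.get_added_vocab() == {}  # Mistral tokenizers define no added tokens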
@@ -141,7 +175,7 @@ class MistralTokenizer:
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self._is_tekken:
+        if isinstance(self.tokenizer, Tekkenizer):
             return "".join(tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
@@ -151,14 +185,11 @@ class MistralTokenizer:
             ids = [ids]
         return self.tokenizer.decode(ids)
 
-    @property
-    def eos_token_id(self):
-        return self.tokenizer.eos_id
-
     def convert_ids_to_tokens(
-            self,
-            ids: List[int],
-            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
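This hunk drops the old eos_token_id property (it now lives with the other properties added above) and tightens the convert_ids_to_tokens signature; the assertion that special tokens are skipped stays in place. A round-trip sketch using the methods in this file, again assuming the tokenizer and prompt_ids from the first sketch:

tokens = tokenizer.convert_ids_to_tokens(prompt_ids)  # skip_special_tokens must stay True
text = tokenizer.convert_tokens_to_string(tokens)
print(text)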
@@ -170,6 +201,3 @@ class MistralTokenizer:
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
         return tokens
-
-    def __len__(self):
-        return self.vocab_size