From 2bcbae704c0d52913c6a2887260fc6bde6c20361 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 5 Nov 2024 21:28:29 -0700 Subject: [PATCH] [Bugfix] Fix edge-case crash when using chat with the Mistral Tekken Tokenizer (#10051) Signed-off-by: Travis Johnson --- tests/models/decoder_only/language/test_mistral.py | 9 ++++++--- vllm/transformers_utils/tokenizers/mistral.py | 8 ++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 5be44c54..6ec4b7e7 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -10,19 +10,22 @@ from ...utils import check_logprobs_close MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", - "mistralai/Mistral-7B-Instruct-v0.3", - # Mistral-Nemo is to big for CI, but passes locally - # "mistralai/Mistral-Nemo-Instruct-2407" ] MISTRAL_FORMAT_MODELS = [ "mistralai/Mistral-7B-Instruct-v0.3", + # uses the v3-Tekken tokenizer + "mistralai/Ministral-8B-Instruct-2410", + # Mistral-Nemo is too big for CI, but passes locally + # "mistralai/Mistral-Nemo-Instruct-2407" ] SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) SYMBOLIC_LANG_PROMPTS = [ "勇敢な船乗りについての詩を書く", # japanese "寫一首關於勇敢的水手的詩", # chinese + "ပုံပြင်လေးပြောပြပါ်:\n", # burmese + "Repeat the phrase 'URGENCY🌶️':\nURGENCY🌶️\nURGENCY🌶️\n", # see https://github.com/vllm-project/vllm/pull/9625 ] # for function calling diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 896f70bc..ccffdcc2 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -254,7 +254,7 @@ class MistralTokenizer: skip_special_tokens: bool = True) -> str: assert ( skip_special_tokens - ), "Skipping special tokens is not supported for Mistral tokenizers." 
+ ), "skip_special_tokens=False is not supported for Mistral tokenizers." if isinstance(ids, int): ids = [ids] @@ -268,12 +268,16 @@ class MistralTokenizer: # TODO(Patrick) - potentially allow special tokens to not be skipped assert ( skip_special_tokens - ), "Skipping special tokens is not supported for Mistral tokenizers." + ), "skip_special_tokens=False is not supported for Mistral tokenizers." assert isinstance(self.tokenizer, (Tekkenizer, SentencePieceTokenizer)), type( self.tokenizer) + if isinstance(self.tokenizer, Tekkenizer): + # skip special tokens + ids = [i for i in ids if i > self.tokenizer.num_special_tokens] + tokens = [self.tokenizer.id_to_piece(id) for id in ids] if any("�" in t for t in tokens):