diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py
index 26f90456..174b905d 100644
--- a/tests/models/decoder_only/language/test_mistral.py
+++ b/tests/models/decoder_only/language/test_mistral.py
@@ -4,7 +4,7 @@
 Run `pytest tests/models/test_mistral.py`.
 """
 import pytest
 
-from vllm import SamplingParams
+from vllm import LLM, SamplingParams
 
 from ...utils import check_logprobs_close
@@ -16,6 +16,10 @@ MODELS = [
 ]
 
 SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+SYMBOLIC_LANG_PROMPTS = [
+    "勇敢な船乗りについての詩を書く",  # japanese
+    "寫一首關於勇敢的水手的詩",  # chinese
+]
 
 # for function calling
 TOOLS = [{
@@ -131,6 +135,25 @@ def test_mistral_format(
     )
 
 
+@pytest.mark.parametrize("model", MODELS[1:])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("prompt", SYMBOLIC_LANG_PROMPTS)
+def test_mistral_symbolic_languages(
+    model: str,
+    dtype: str,
+    prompt: str,
+) -> None:
+    msg = {"role": "user", "content": prompt}
+    llm = LLM(model=model,
+              dtype=dtype,
+              max_model_len=8192,
+              tokenizer_mode="mistral",
+              config_format="mistral",
+              load_format="mistral")
+    outputs = llm.chat([msg], sampling_params=SAMPLING_PARAMS)
+    assert "�" not in outputs[0].outputs[0].text.strip()
+
+
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("model", MODELS[1:])  # v1 can't do func calling
 def test_mistral_function_calling(
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 7a228a3e..78813305 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -175,10 +175,29 @@ class MistralTokenizer:
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         if isinstance(self.tokenizer, Tekkenizer):
-            return "".join(t for t in tokens
-                           if t not in self.tokenizer._all_special_tokens)
+            tokens = [
+                t for t in tokens
+                if t not in self.tokenizer._all_special_tokens
+            ]
+
+            if any(isinstance(t, bytes) for t in tokens):
+                # we need to encode and decode all tokens again
+                shift = self.tokenizer.num_special_tokens
+                byte_tokens = [
+                    t.encode("utf-8") if not isinstance(t, bytes) else t
+                    for t in tokens
+                ]
+                ids = [
+                    self.tokenizer._tekken_token2id_nospecial[t] + shift
+                    for t in byte_tokens
+                ]
+                decoded = self.tokenizer.decode(ids)
+            else:
+                decoded = "".join(tokens)
         else:
-            return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
+            decoded = self.tokenizer.decode(tokens)  # type: ignore[arg-type]
+
+        return decoded
 
     def decode(self, ids: Union[List[int], int]) -> str:
         if isinstance(ids, int):
@@ -200,4 +219,11 @@
             self.tokenizer)
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
+
+        if any(t.strip() == "�" for t in tokens):
+            # if any stripped decoded token is undefined
+            # because it's invalid unicode then pass bytes
+            # See: https://github.com/vllm-project/vllm/pull/8640
+            tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
+
         return tokens