diff --git a/tests/engine/test_detokenize.py b/tests/engine/test_detokenize.py
new file mode 100644
index 00000000..40590470
--- /dev/null
+++ b/tests/engine/test_detokenize.py
@@ -0,0 +1,55 @@
+import pytest
+
+from transformers import AutoTokenizer
+
+from vllm.transformers_utils.tokenizer import detokenize_incrementally
+
+TRUTH = [
+    "Hello here, this is a simple test",
+    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
+    "我很感谢你的热情"
+]
+TOKENIZERS = [
+    "facebook/opt-125m",
+    "gpt2",
+    "bigcode/tiny_starcoder_py",
+    "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-70m",
+    "bigscience/bloom-560m",
+    "mosaicml/mpt-7b",
+    "tiiuae/falcon-7b",
+    "meta-llama/Llama-2-7b-hf",
+    "codellama/CodeLlama-7b-hf",
+]
+
+
+def _run_incremental_decode(tokenizer, all_input_ids):
+    decoded_text = ""
+    offset = 0
+    token_offset = 0
+    prev_tokens = None
+    for i in range(len(all_input_ids)):
+        new_tokens, text, offset, token_offset = detokenize_incrementally(
+            tokenizer,
+            all_input_ids[:i + 1],
+            prev_tokens,
+            offset,
+            token_offset,
+            skip_special_tokens=False)
+        decoded_text += text
+        if prev_tokens is None:
+            prev_tokens = new_tokens
+        else:
+            prev_tokens += new_tokens
+    return decoded_text
+
+
+@pytest.mark.parametrize("truth", TRUTH)
+@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
+def test_decode_streaming(tokenizer_id, truth):
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+
+    decoded_text = _run_incremental_decode(tokenizer, all_input_ids)
+
+    assert decoded_text == truth
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 83dfbc2e..74093cf4 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -623,15 +623,22 @@ class LLMEngine:
 
     def _decode_sequence(self, seq: Sequence) -> None:
         """Decodes the new token for a sequence."""
-        new_token, new_output_text = detokenize_incrementally(
-            self.tokenizer,
-            seq.output_tokens,
-            seq.get_last_token_id(),
-            skip_special_tokens=True,
-        )
-        if new_token is not None:
-            seq.output_tokens.append(new_token)
-        seq.output_text = new_output_text
+        (new_tokens, new_output_text, prefix_offset,
+         read_offset) = detokenize_incrementally(
+             self.tokenizer,
+             all_input_ids=seq.get_token_ids(),
+             prev_tokens=seq.tokens,
+             prefix_offset=seq.prefix_offset,
+             read_offset=seq.read_offset,
+             skip_special_tokens=True,
+         )
+        if seq.tokens is None:
+            seq.tokens = new_tokens
+        else:
+            seq.tokens.extend(new_tokens)
+        seq.prefix_offset = prefix_offset
+        seq.read_offset = read_offset
+        seq.output_text += new_output_text
 
     def _check_stop(self, seq: Sequence,
                     sampling_params: SamplingParams) -> None:
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 74682786..795397a3 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -114,7 +114,6 @@ class Sequence:
 
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
-        self.output_tokens: List[str] = []
         self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -122,6 +121,12 @@
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
 
+        # Used for incremental detokenization
+        self.prefix_offset = 0
+        self.read_offset = 0
+        # Input + output tokens
+        self.tokens: Optional[List[str]] = None
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 4eb941d3..d1275a1d 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
@@ -67,33 +67,11 @@ def get_tokenizer(
     return tokenizer
 
 
-def detokenize_incrementally(
+def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    prev_output_tokens: List[str],
-    new_token_id: int,
+    output_tokens: List[str],
     skip_special_tokens: bool,
-) -> Tuple[str, str]:
-    """Detokenizes the new token in conjunction with the previous output tokens.
-
-    NOTE: This function does not update prev_output_tokens.
-
-    Returns:
-        new_token: The new token as a string.
-        output_text: The new output text as a string.
-    """
-    if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
-        return None, prev_output_tokens
-    new_token = tokenizer.convert_ids_to_tokens(
-        new_token_id, skip_special_tokens=skip_special_tokens)
-    output_tokens = prev_output_tokens + [new_token]
-
-    # Convert the tokens to a string.
-    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
-    # then we can directly use `convert_tokens_to_string`.
-    if not getattr(tokenizer, "added_tokens_encoder", {}):
-        output_text = tokenizer.convert_tokens_to_string(output_tokens)
-        return new_token, output_text
-
+) -> str:
     # Adapted from
     # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
     # NOTE(woosuk): The following code is slow because it runs a for loop over
@@ -115,5 +93,61 @@
     if current_sub_text:
         sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
         sub_texts.append(sub_text)
-    output_text = " ".join(sub_texts)
-    return new_token, output_text
+    return " ".join(sub_texts)
+
+
+# Based on
+# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
+# under Apache 2.0 license
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    all_input_ids: List[int],
+    prev_tokens: Optional[List[str]],
+    prefix_offset: int = 0,
+    read_offset: int = 0,
+    skip_special_tokens: bool = False,
+) -> Tuple[List[str], str, int, int]:
+    new_token_id = all_input_ids[-1]
+    # This is the first iteration for this sequence
+    if prev_tokens is None:
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            all_input_ids, skip_special_tokens=skip_special_tokens)
+        output_tokens = new_tokens
+        # 5 is an arbitrary value that should work for all
+        # tokenizers (bigger = more conservative).
+        # Subtract 1 extra to account for the generated token.
+        prefix_offset = max(len(output_tokens) - 6, 0)
+        read_offset = max(len(output_tokens) - 1, 0)
+    else:
+        new_token = tokenizer.convert_ids_to_tokens(
+            new_token_id, skip_special_tokens=skip_special_tokens)
+        new_tokens = [new_token]
+        output_tokens = prev_tokens + new_tokens
+
+    # The prefix text is necessary only to defeat cleanup algorithms in
+    # the decode which decide to add a space or not depending on the
+    # surrounding ids.
+    if not getattr(tokenizer, "added_tokens_encoder", {}):
+        prefix_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:read_offset])
+        new_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:])
+    else:
+        prefix_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens)
+        new_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:],
+            skip_special_tokens=skip_special_tokens)
+
+    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+        # utf-8 char at the end means it's a potential unfinished byte sequence
+        # from byte fallback tokenization.
+        # If it's in the middle, it's probably a real invalid id generated
+        # by the model
+        new_text = new_text[len(prefix_text):]
+        return new_tokens, new_text, read_offset, len(output_tokens)
+    else:
+        return new_tokens, "", prefix_offset, read_offset
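For reviewers, here is a minimal sketch (not part of the diff) of how the new detokenize_incrementally API is meant to be driven step by step. It mirrors the per-sequence state this diff adds to Sequence (tokens, prefix_offset, read_offset) and the call in LLMEngine._decode_sequence; the tokenizer and prompt below are arbitrary choices borrowed from the test above.

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")
truth = "Hello here, this is a simple test"
all_token_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]

# Per-sequence detokenization state, as stored on Sequence in this diff.
tokens = None       # input + output tokens seen so far (Sequence.tokens)
prefix_offset = 0   # Sequence.prefix_offset
read_offset = 0     # Sequence.read_offset
streamed_text = ""

# Feed the ids one at a time, as the engine does after each decoding step.
for i in range(len(all_token_ids)):
    new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer,
        all_input_ids=all_token_ids[:i + 1],
        prev_tokens=tokens,
        prefix_offset=prefix_offset,
        read_offset=read_offset,
        skip_special_tokens=False)
    tokens = new_tokens if tokens is None else tokens + new_tokens
    # new_text is "" while a multi-byte character is still incomplete.
    streamed_text += new_text

assert streamed_text == truth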