Use TGI-like incremental detokenization (#984)
parent 3272d7a0b7
commit 9841d48a10
tests/engine/test_detokenize.py (new file, 55 lines added)
@@ -0,0 +1,55 @@
+import pytest
+
+from transformers import AutoTokenizer
+
+from vllm.transformers_utils.tokenizer import detokenize_incrementally
+
+TRUTH = [
+    "Hello here, this is a simple test",
+    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
+    "我很感谢你的热情"
+]
+TOKENIZERS = [
+    "facebook/opt-125m",
+    "gpt2",
+    "bigcode/tiny_starcoder_py",
+    "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-70m",
+    "bigscience/bloom-560m",
+    "mosaicml/mpt-7b",
+    "tiiuae/falcon-7b",
+    "meta-llama/Llama-2-7b-hf",
+    "codellama/CodeLlama-7b-hf",
+]
+
+
+def _run_incremental_decode(tokenizer, all_input_ids):
+    decoded_text = ""
+    offset = 0
+    token_offset = 0
+    prev_tokens = None
+    for i in range(len(all_input_ids)):
+        new_tokens, text, offset, token_offset = detokenize_incrementally(
+            tokenizer,
+            all_input_ids[:i + 1],
+            prev_tokens,
+            offset,
+            token_offset,
+            skip_special_tokens=False)
+        decoded_text += text
+        if prev_tokens is None:
+            prev_tokens = new_tokens
+        else:
+            prev_tokens += new_tokens
+    return decoded_text
+
+
+@pytest.mark.parametrize("truth", TRUTH)
+@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
+def test_decode_streaming(tokenizer_id, truth):
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+
+    decoded_text = _run_incremental_decode(tokenizer, all_input_ids)
+
+    assert decoded_text == truth
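The new test file can be exercised on its own. Below is a small launcher sketch (not part of the commit) equivalent to running pytest on the file from the repository root; it assumes pytest and transformers are installed and that the tokenizers listed above can be fetched from the Hugging Face Hub (meta-llama/Llama-2-7b-hf and codellama/CodeLlama-7b-hf may additionally require access approval and Hub credentials).

# Hypothetical launcher for the new test module; equivalent to
# `pytest tests/engine/test_detokenize.py -q` from the repo root.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["tests/engine/test_detokenize.py", "-q"]))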
@@ -623,15 +623,22 @@ class LLMEngine:
 
     def _decode_sequence(self, seq: Sequence) -> None:
         """Decodes the new token for a sequence."""
-        new_token, new_output_text = detokenize_incrementally(
-            self.tokenizer,
-            seq.output_tokens,
-            seq.get_last_token_id(),
-            skip_special_tokens=True,
-        )
-        if new_token is not None:
-            seq.output_tokens.append(new_token)
-            seq.output_text = new_output_text
+        (new_tokens, new_output_text, prefix_offset,
+         read_offset) = detokenize_incrementally(
+             self.tokenizer,
+             all_input_ids=seq.get_token_ids(),
+             prev_tokens=seq.tokens,
+             prefix_offset=seq.prefix_offset,
+             read_offset=seq.read_offset,
+             skip_special_tokens=True,
+         )
+        if seq.tokens is None:
+            seq.tokens = new_tokens
+        else:
+            seq.tokens.extend(new_tokens)
+        seq.prefix_offset = prefix_offset
+        seq.read_offset = read_offset
+        seq.output_text += new_output_text
 
     def _check_stop(self, seq: Sequence,
                     sampling_params: SamplingParams) -> None:
@@ -114,7 +114,6 @@ class Sequence:
 
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
-        self.output_tokens: List[str] = []
         self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -122,6 +121,12 @@ class Sequence:
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
 
+        # Used for incremental detokenization
+        self.prefix_offset = 0
+        self.read_offset = 0
+        # Input + output tokens
+        self.tokens: Optional[List[str]] = None
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
@@ -67,33 +67,11 @@ def get_tokenizer(
     return tokenizer
 
 
-def detokenize_incrementally(
+def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    prev_output_tokens: List[str],
-    new_token_id: int,
+    output_tokens: List[str],
     skip_special_tokens: bool,
-) -> Tuple[str, str]:
-    """Detokenizes the new token in conjunction with the previous output tokens.
-
-    NOTE: This function does not update prev_output_tokens.
-
-    Returns:
-        new_token: The new token as a string.
-        output_text: The new output text as a string.
-    """
-    if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
-        return None, prev_output_tokens
-    new_token = tokenizer.convert_ids_to_tokens(
-        new_token_id, skip_special_tokens=skip_special_tokens)
-    output_tokens = prev_output_tokens + [new_token]
-
-    # Convert the tokens to a string.
-    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
-    # then we can directly use `convert_tokens_to_string`.
-    if not getattr(tokenizer, "added_tokens_encoder", {}):
-        output_text = tokenizer.convert_tokens_to_string(output_tokens)
-        return new_token, output_text
-
+) -> str:
     # Adapted from
     # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
     # NOTE(woosuk): The following code is slow because it runs a for loop over
@@ -115,5 +93,61 @@ def detokenize_incrementally(
     if current_sub_text:
         sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
         sub_texts.append(sub_text)
-    output_text = " ".join(sub_texts)
-    return new_token, output_text
+    return " ".join(sub_texts)
+
+
+# Based on
+# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
+# under Apache 2.0 license
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    all_input_ids: List[int],
+    prev_tokens: Optional[List[str]],
+    prefix_offset: int = 0,
+    read_offset: int = 0,
+    skip_special_tokens: bool = False,
+) -> Tuple[List[str], str, int, int]:
+    new_token_id = all_input_ids[-1]
+    # This is the first iteration for this sequence
+    if prev_tokens is None:
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            all_input_ids, skip_special_tokens=skip_special_tokens)
+        output_tokens = new_tokens
+        # 5 is an arbitrary value that should work for all
+        # tokenizers (bigger = more conservative).
+        # Subtract 1 extra to account for the generated token.
+        prefix_offset = max(len(output_tokens) - 6, 0)
+        read_offset = max(len(output_tokens) - 1, 0)
+    else:
+        new_token = tokenizer.convert_ids_to_tokens(
+            new_token_id, skip_special_tokens=skip_special_tokens)
+        new_tokens = [new_token]
+        output_tokens = prev_tokens + new_tokens
+
+    # The prefix text is necessary only to defeat cleanup algorithms in
+    # the decode which decide to add a space or not depending on the
+    # surrounding ids.
+    if not getattr(tokenizer, "added_tokens_encoder", {}):
+        prefix_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:read_offset])
+        new_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:])
+    else:
+        prefix_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens)
+        new_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:],
+            skip_special_tokens=skip_special_tokens)
+
+    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+        # utf-8 char at the end means it's a potential unfinished byte sequence
+        # from byte fallback tokenization.
+        # If it's in the middle, it's probably a real invalid id generated
+        # by the model
+        new_text = new_text[len(prefix_text):]
+        return new_tokens, new_text, read_offset, len(output_tokens)
+    else:
+        return new_tokens, "", prefix_offset, read_offset
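For reference, here is a minimal usage sketch (not part of the commit) showing how a caller could drive the new detokenize_incrementally API over a growing list of token ids, mirroring both the test helper above and the per-sequence bookkeeping that LLMEngine._decode_sequence now performs. It assumes vLLM with this change is importable and that the gpt2 tokenizer, one of those covered by the test, is available from the Hugging Face Hub.

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_ids = tokenizer("Hello here, this is a simple test",
                    add_special_tokens=False)["input_ids"]

tokens = None        # plays the role of seq.tokens
prefix_offset = 0    # plays the role of seq.prefix_offset
read_offset = 0      # plays the role of seq.read_offset
output_text = ""     # plays the role of seq.output_text

for i in range(len(all_ids)):
    # Pass the ids seen so far, as the engine does after each sampled token.
    new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer,
        all_input_ids=all_ids[:i + 1],
        prev_tokens=tokens,
        prefix_offset=prefix_offset,
        read_offset=read_offset,
        skip_special_tokens=False)
    tokens = new_tokens if tokens is None else tokens + new_tokens
    output_text += new_text  # only completed text is appended each step

assert output_text == "Hello here, this is a simple test"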