[Core] Improve detokenization performance for prefill (#3469)

Co-authored-by: MeloYang <meloyang05@gmail.com>
Antoni Baum 2024-03-22 13:44:12 -07:00 committed by GitHub
parent cf2f084d56
commit bfdb1ba5c3
4 changed files with 385 additions and 89 deletions
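
The core idea: during prefill, the prompt is no longer converted to token strings in full (and repeatedly, once per prompt-logprob position). Only the last few prompt ids are converted to seed incremental detokenization, and prompt-logprob decoding reuses the resulting tokens and offsets across positions. Below is a minimal standalone sketch of the seeding step, using plain HuggingFace APIs and a hypothetical helper name that mirrors the convert_prompt_ids_to_tokens function added in this diff.

from transformers import AutoTokenizer

# Mirrors INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET introduced in this diff.
OFFSET = 5


def seed_from_prompt(tokenizer, prompt_ids):
    # Hypothetical stand-in: only the tail of the prompt is materialized as
    # token strings, which is all that incremental detokenization of the
    # generated tokens needs. Offset a little extra for special tokens.
    start = max(len(prompt_ids) - OFFSET - 2, 0)
    tokens = tokenizer.convert_ids_to_tokens(prompt_ids[start:])
    prefix_offset = max(len(tokens) - OFFSET, 0)
    read_offset = len(tokens)
    return tokens, prefix_offset, read_offset


tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
prompt_ids = tokenizer("a long prompt " * 200)["input_ids"]
tokens, prefix_offset, read_offset = seed_from_prompt(tokenizer, prompt_ids)
# Only a handful of trailing tokens are materialized, regardless of prompt length.
print(len(prompt_ids), len(tokens), prefix_offset, read_offset)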

View File

@@ -1,13 +1,17 @@
import pytest
from transformers import AutoTokenizer
from typing import List, Dict
from vllm.sequence import Sequence, Logprob, SamplingParams, SequenceGroup
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer import detokenize_incrementally
from vllm.transformers_utils.detokenizer import Detokenizer
TRUTH = [
"Hello here, this is a simple test", # noqa: E501
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa: E501
"我很感谢你的热情" # noqa: E501
"Hello here, this is a simple test",
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa
"我很感谢你的热情"
]
TOKENIZERS = [
"facebook/opt-125m",
@@ -24,12 +28,12 @@ TOKENIZERS = [
def _run_incremental_decode(tokenizer, all_input_ids,
skip_special_tokens: bool):
skip_special_tokens: bool, starting_index: int):
decoded_text = ""
offset = 0
token_offset = 0
prev_tokens = None
for i in range(len(all_input_ids)):
for i in range(starting_index, len(all_input_ids)):
new_tokens, text, offset, token_offset = detokenize_incrementally(
tokenizer,
all_input_ids[:i + 1],
@@ -46,17 +50,152 @@
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
def test_decode_streaming(tokenizer_id, truth, with_prompt,
skip_special_tokens):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
prompt_input_ids = truth_tokens[:len(truth) // 2]
generated_input_ids = truth_tokens[len(truth) // 2:]
all_input_ids = prompt_input_ids + generated_input_ids
starting_index = len(prompt_input_ids)
prompt = tokenizer.decode(prompt_input_ids,
skip_special_tokens=skip_special_tokens)
generated = truth[len(prompt):]
else:
generated = truth
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
if skip_special_tokens:
all_input_ids = ([tokenizer.bos_token_id]
if tokenizer.bos_token_id is not None else
[]) + all_input_ids + [tokenizer.eos_token_id]
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
starting_index += 1
all_input_ids = all_input_ids + [tokenizer.eos_token_id]
decoded_text = _run_incremental_decode(
tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
tokenizer,
all_input_ids,
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
assert decoded_text == truth
assert decoded_text == generated
@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
init_kwargs = dict(
tokenizer_id=tokenizer_name,
enable_lora=False,
max_num_seqs=100,
max_input_length=None,
tokenizer_mode="auto",
trust_remote_code=False,
revision=None,
)
tokenizer_group = get_tokenizer_group(
None,
**init_kwargs,
)
return Detokenizer(tokenizer_group)
@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer_name: str) -> List[int]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
return complete_sequence_token_ids
def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
return Sequence(
seq_id=0,
prompt="<s>",
prompt_token_ids=prompt_token_ids,
block_size=16,
)
def create_dummy_logprobs(
complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
return [{
token_id: Logprob(logprob=0.0),
token_id + 1: Logprob(logprob=0.1)
} for token_id in complete_sequence_token_ids]
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
logprobs=2)
# Run sequentially.
seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token = []
sequential_logprobs_text_other_token = []
for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs):
seq.append_token_id(new_token, logprobs)
detokenizer.decode_sequence_inplace(seq, sampling_params)
sequential_logprobs_text_chosen_token.append(
seq.output_logprobs[-1][new_token].decoded_token)
sequential_logprobs_text_other_token.append(
seq.output_logprobs[-1][new_token + 1].decoded_token)
sequential_result = seq.output_text
assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
assert sequential_result != "".join(sequential_logprobs_text_other_token)
if skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# special tokens.
assert sequential_result == complete_sequence
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True])
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1)
# Run sequentially.
seq = create_sequence(complete_sequence_token_ids)
seq_group = SequenceGroup(request_id="1",
seqs=[seq],
sampling_params=sampling_params,
arrival_time=0.0)
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
decoded_prompt_logprobs = dummy_logprobs
if skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that this will only be true if we skip
# special tokens.
assert complete_sequence == "".join([
logprobs[token_id].decoded_token for token_id, logprobs in zip(
complete_sequence_token_ids, decoded_prompt_logprobs)
])
assert complete_sequence != "".join([
logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
complete_sequence_token_ids, decoded_prompt_logprobs)
])

View File

@@ -1,5 +1,5 @@
import time
from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Iterable, List, Optional, Tuple, Type, Union
from transformers import PreTrainedTokenizer
@@ -15,11 +15,11 @@ from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
@@ -97,6 +97,7 @@ class LLMEngine:
self._verify_args()
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
self.seq_counter = Counter()
self.model_executor = executor_class(model_config, cache_config,
@@ -153,7 +154,7 @@
raise RuntimeError("LLMEngine should not be pickled!")
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer()
return self.tokenizer.get_lora_tokenizer(None)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
@@ -370,13 +371,8 @@
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
# We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values()))
all_token_ids = seq.get_token_ids()
for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
self._decode_logprobs(seq, seq_group.sampling_params,
prompt_logprobs_for_token,
all_token_ids[:i])
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
seq_group.prompt_logprobs = prompt_logprobs
# Process samples
@@ -420,7 +416,8 @@
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
self._decode_sequence(seq, seq_group.sampling_params)
self.detokenizer.decode_sequence_inplace(seq,
seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
@@ -713,51 +710,6 @@
time_e2e_requests=time_e2e_requests,
)
def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
logprobs: Dict[int, Logprob],
all_input_ids: List[int]) -> None:
if not logprobs:
return
for token_id, sample_logprob in logprobs.items():
if (sample_logprob.decoded_token is None and token_id != -1):
all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
(_, new_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
all_input_ids = seq.get_token_ids()
self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
all_input_ids)
(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.spaces_between_special_tokens,
)
if seq.tokens is None:
seq.tokens = new_tokens
else:
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset
seq.read_offset = read_offset
seq.output_text += new_output_text
def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""

View File

@@ -0,0 +1,155 @@
from typing import List, Dict, Optional
from transformers import PreTrainedTokenizer
from vllm.sequence import Sequence, Logprob, SequenceGroup, SamplingParams
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
convert_prompt_ids_to_tokens)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
# Used e.g. for marking rejected tokens in spec decoding.
INVALID_TOKEN_ID = -1
class Detokenizer:
"""Provides methods to decode the output of a model into text."""
def __init__(self, tokenizer_group: BaseTokenizerGroup):
self.tokenizer_group = tokenizer_group
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
def decode_prompt_logprobs_inplace(
self, seq_group: SequenceGroup,
prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
"""Decodes the logprobs for the prompt of a sequence group.
Args:
seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode.
Returns:
None; the decoded token strings are written into `prompt_logprobs` in place.
"""
prms = seq_group.sampling_params
# We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values()))
# Only prompt, without the generated token.
all_token_ids = seq.get_token_ids()
prompt_token_ids = all_token_ids[:-1]
tokenizer = self.get_tokenizer_for_seq(seq)
prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
next_iter_read_offset = 0
next_iter_tokens = []
prev_tokens = None
for token_position, prompt_logprobs_for_token in enumerate(
prompt_logprobs):
if not prompt_logprobs_for_token:
continue
for token_id, sample_logprob in prompt_logprobs_for_token.items():
if (sample_logprob.decoded_token is None
and token_id != INVALID_TOKEN_ID):
prompt_token_ids_with_token = (
prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=prompt_token_ids_with_token,
prev_tokens=prev_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
# Use the offsets & prev tokens corresponding to
# real tokens, so that detokenization stays consistent
# with the actual prompt.
if token_id == all_token_ids[token_position]:
next_iter_prefix_offset = new_prefix_offset
next_iter_read_offset = new_read_offset
next_iter_tokens = new_tokens
# Advance to the next token position.
prefix_offset = next_iter_prefix_offset
read_offset = next_iter_read_offset
if prev_tokens is None:
prev_tokens = next_iter_tokens
else:
prev_tokens.extend(next_iter_tokens)
def decode_sequence_inplace(self, seq: Sequence,
prms: SamplingParams) -> None:
"""Decodes the new token for a sequence. In-place operation.
Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
"""
all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
# computation for each logprob.
if seq.tokens is None:
(seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens,
)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.spaces_between_special_tokens,
)
# Decode logprobs
logprobs = seq.output_logprobs[-1]
if logprobs:
previous_tokens = all_input_ids[:-1]
for token_id, sample_logprob in logprobs.items():
# If the token was generated this iteration,
# use the provided text.
if token_id == token_id_generated_this_iteration:
sample_logprob.decoded_token = new_decoded_token_text
continue
if (sample_logprob.decoded_token is None
and token_id != INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
if seq.tokens is None:
seq.tokens = new_tokens
else:
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset
seq.read_offset = read_offset
seq.output_text += new_decoded_token_text
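
An end-to-end usage sketch of the new class, illustrative only, assuming the same setup as the tests above (the facebook/opt-125m tokenizer, no LoRA):

from transformers import AutoTokenizer

from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, Sequence
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

tokenizer_group = get_tokenizer_group(
    None,
    tokenizer_id="facebook/opt-125m",
    enable_lora=False,
    max_num_seqs=100,
    max_input_length=None,
    tokenizer_mode="auto",
    trust_remote_code=False,
    revision=None,
)
detokenizer = Detokenizer(tokenizer_group)

hf_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
prompt = "vLLM is"
prompt_ids = hf_tokenizer(prompt, add_special_tokens=False)["input_ids"]
generated_ids = hf_tokenizer(" a fast inference engine",
                             add_special_tokens=False)["input_ids"]

seq = Sequence(seq_id=0, prompt=prompt, prompt_token_ids=prompt_ids,
               block_size=16)
params = SamplingParams(skip_special_tokens=True)
for token_id in generated_ids:
    # Append each "sampled" token with a trivial logprob, then decode in place.
    seq.append_token_id(token_id, {token_id: Logprob(logprob=0.0)})
    detokenizer.decode_sequence_inplace(seq, params)
print(seq.output_text)  # The incrementally detokenized generation.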

View File

@@ -158,6 +158,34 @@ def _convert_tokens_to_string_with_added_encoders(
return "".join(sub_texts)
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def convert_prompt_ids_to_tokens(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt_ids: List[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.
Note that only the tokens needed for incremental detokenization are
converted to strings.
"""
# Offset a little more in case we have special tokens.
prefix_offset = max(
len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
# We do not need to convert the whole prompt to tokens.
new_tokens = tokenizer.convert_ids_to_tokens(
prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
prefix_offset = max(
len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
read_offset = len(new_tokens)
return new_tokens, prefix_offset, read_offset
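
A short usage sketch of the new helper feeding detokenize_incrementally for the first generated token (illustrative; the tokenizer choice and example strings are arbitrary):

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import (convert_prompt_ids_to_tokens,
                                               detokenize_incrementally)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
prompt_ids = tokenizer("Hello here, this is a simple test",
                       add_special_tokens=False)["input_ids"]
new_id = tokenizer(" indeed", add_special_tokens=False)["input_ids"][0]

# Seed incremental detokenization from the tail of the prompt only.
prev_tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
    tokenizer, prompt_ids)

new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
    tokenizer,
    all_input_ids=prompt_ids + [new_id],
    prev_tokens=prev_tokens,
    prefix_offset=prefix_offset,
    read_offset=read_offset,
    skip_special_tokens=False,
)
print(new_text)  # Text contributed by the newly generated token.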
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
@@ -165,31 +193,53 @@ def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int = 0,
read_offset: int = 0,
prefix_offset: int,
read_offset: int,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
If `prev_tokens` is None, this function converts all of the input ids to
tokens and returns them together with the new text; otherwise it returns
only the newly converted tokens and the new text.
This function will also return the new prefix offset and the new read
offset to be used in the next iteration.
The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.
Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
if prev_tokens is None:
new_tokens = tokenizer.convert_ids_to_tokens(
all_input_ids, skip_special_tokens=skip_special_tokens)
output_tokens = new_tokens
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
# Subtract 1 extra to account for the generated token.
prefix_offset = max(len(output_tokens) - 6, 0)
# If the first new token is a special token, we can't skip 1 extra token
if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
read_offset = max(len(output_tokens), 0)
else:
read_offset = max(len(output_tokens) - 1, 0)
else:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
output_tokens = prev_tokens + new_tokens
is_first_iter = prev_tokens is None
if is_first_iter:
(prev_tokens, prefix_offset,
read_offset) = convert_prompt_ids_to_tokens(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens.
if is_first_iter:
new_tokens = output_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the