[Core] Improve detokenization performance for prefill (#3469)
Co-authored-by: MeloYang <meloyang05@gmail.com>
Parent: cf2f084d56
Commit: bfdb1ba5c3
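The sketch below (not part of the diff) illustrates the idea behind the change, assuming the facebook/opt-125m tokenizer used by the tests: for prefill, only a short tail of the prompt ids is converted to token strings to seed incremental detokenization, instead of detokenizing the whole prompt token by token. The helper name is illustrative; the real implementation is the new convert_prompt_ids_to_tokens added in the final hunk.

# Illustrative sketch only, not part of the diff. Mirrors the spirit of
# convert_prompt_ids_to_tokens: convert just the tail of the prompt so the
# incremental detokenizer has enough context to start streaming.
from transformers import AutoTokenizer

OFFSET = 5  # same spirit as INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET below


def seed_incremental_decode(tokenizer, prompt_ids):
    # Convert only the last few prompt ids to token strings.
    prefix_offset = max(len(prompt_ids) - OFFSET - 2, 0)
    tail_tokens = tokenizer.convert_ids_to_tokens(prompt_ids[prefix_offset:])
    # Offsets are relative to the (short) token list we just built.
    prefix_offset = max(len(tail_tokens) - OFFSET, 0)
    read_offset = len(tail_tokens)
    return tail_tokens, prefix_offset, read_offset


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
    ids = tok("Hello here, this is a simple test")["input_ids"]
    print(seed_incremental_decode(tok, ids))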
@@ -1,13 +1,17 @@
import pytest
from transformers import AutoTokenizer
from typing import List, Dict

from vllm.sequence import Sequence, Logprob, SamplingParams, SequenceGroup
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer import detokenize_incrementally
from vllm.transformers_utils.detokenizer import Detokenizer

TRUTH = [
    "Hello here, this is a simple test",  # noqa: E501
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa: E501
    "我很感谢你的热情"  # noqa: E501
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情"
]

TOKENIZERS = [
    "facebook/opt-125m",
@@ -24,12 +28,12 @@ TOKENIZERS = [


def _run_incremental_decode(tokenizer, all_input_ids,
                            skip_special_tokens: bool):
                            skip_special_tokens: bool, starting_index: int):
    decoded_text = ""
    offset = 0
    token_offset = 0
    prev_tokens = None
    for i in range(len(all_input_ids)):
    for i in range(starting_index, len(all_input_ids)):
        new_tokens, text, offset, token_offset = detokenize_incrementally(
            tokenizer,
            all_input_ids[:i + 1],
@@ -46,17 +50,152 @@ def _run_incremental_decode(tokenizer, all_input_ids,


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
def test_decode_streaming(tokenizer_id, truth, with_prompt,
                          skip_special_tokens):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
    if with_prompt:
        truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
        prompt_input_ids = truth_tokens[:len(truth) // 2]
        generated_input_ids = truth_tokens[len(truth) // 2:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(prompt_input_ids,
                                  skip_special_tokens=skip_special_tokens)
        generated = truth[len(prompt):]
    else:
        generated = truth
        starting_index = 0
        all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
    if skip_special_tokens:
        all_input_ids = ([tokenizer.bos_token_id]
                         if tokenizer.bos_token_id is not None else
                         []) + all_input_ids + [tokenizer.eos_token_id]
        if tokenizer.bos_token_id is not None:
            all_input_ids = [tokenizer.bos_token_id] + all_input_ids
            starting_index += 1
        all_input_ids = all_input_ids + [tokenizer.eos_token_id]

    decoded_text = _run_incremental_decode(
        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index)

    assert decoded_text == truth
    assert decoded_text == generated

@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
    init_kwargs = dict(
        tokenizer_id=tokenizer_name,
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
        tokenizer_mode="auto",
        trust_remote_code=False,
        revision=None,
    )

    tokenizer_group = get_tokenizer_group(
        None,
        **init_kwargs,
    )

    return Detokenizer(tokenizer_group)


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
                                       tokenizer_name: str) -> List[int]:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
    return complete_sequence_token_ids


def create_sequence(prompt_token_ids=None):
    prompt_token_ids = prompt_token_ids or [1]
    return Sequence(
        seq_id=0,
        prompt="<s>",
        prompt_token_ids=prompt_token_ids,
        block_size=16,
    )


def create_dummy_logprobs(
        complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
    return [{
        token_id: Logprob(logprob=0.0),
        token_id + 1: Logprob(logprob=0.1)
    } for token_id in complete_sequence_token_ids]

@pytest.mark.parametrize("complete_sequence", TRUTH)
|
||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
||||
@pytest.mark.parametrize("skip_special_tokens", [True, False])
|
||||
def test_decode_sequence_logprobs(complete_sequence: str,
|
||||
complete_sequence_token_ids: List[int],
|
||||
detokenizer: Detokenizer,
|
||||
skip_special_tokens: bool):
|
||||
"""Verify Detokenizer decodes logprobs correctly."""
|
||||
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
|
||||
logprobs=2)
|
||||
|
||||
# Run sequentially.
|
||||
seq = create_sequence()
|
||||
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
|
||||
sequential_logprobs_text_chosen_token = []
|
||||
sequential_logprobs_text_other_token = []
|
||||
for new_token, logprobs in zip(complete_sequence_token_ids,
|
||||
dummy_logprobs):
|
||||
seq.append_token_id(new_token, logprobs)
|
||||
detokenizer.decode_sequence_inplace(seq, sampling_params)
|
||||
sequential_logprobs_text_chosen_token.append(
|
||||
seq.output_logprobs[-1][new_token].decoded_token)
|
||||
sequential_logprobs_text_other_token.append(
|
||||
seq.output_logprobs[-1][new_token + 1].decoded_token)
|
||||
sequential_result = seq.output_text
|
||||
|
||||
assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
|
||||
assert sequential_result != "".join(sequential_logprobs_text_other_token)
|
||||
|
||||
if skip_special_tokens:
|
||||
# Text for logprobs for the chosen token should be the same as the
|
||||
# generated text. Note that this will only be true if we skip
|
||||
# special tokens.
|
||||
assert sequential_result == complete_sequence
|
||||
|
||||
|
||||
@pytest.mark.parametrize("complete_sequence", TRUTH)
|
||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
||||
@pytest.mark.parametrize("skip_special_tokens", [True])
|
||||
def test_decode_prompt_logprobs(complete_sequence: str,
|
||||
complete_sequence_token_ids: List[int],
|
||||
detokenizer: Detokenizer,
|
||||
skip_special_tokens: bool):
|
||||
"""Verify Detokenizer decodes prompt logprobs correctly."""
|
||||
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
|
||||
prompt_logprobs=1)
|
||||
|
||||
# Run sequentially.
|
||||
seq = create_sequence(complete_sequence_token_ids)
|
||||
seq_group = SequenceGroup(request_id="1",
|
||||
seqs=[seq],
|
||||
sampling_params=sampling_params,
|
||||
arrival_time=0.0)
|
||||
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
|
||||
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
|
||||
decoded_prompt_logprobs = dummy_logprobs
|
||||
|
||||
if skip_special_tokens:
|
||||
# Text for logprobs for the chosen token should be the same as the
|
||||
# prompt text. Note that this will only be true if we skip
|
||||
# special tokens.
|
||||
assert complete_sequence == "".join([
|
||||
logprobs[token_id].decoded_token for token_id, logprobs in zip(
|
||||
complete_sequence_token_ids, decoded_prompt_logprobs)
|
||||
])
|
||||
assert complete_sequence != "".join([
|
||||
logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
|
||||
complete_sequence_token_ids, decoded_prompt_logprobs)
|
||||
])
|
||||
|
||||
@@ -1,5 +1,5 @@
import time
from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Iterable, List, Optional, Tuple, Type, Union

from transformers import PreTrainedTokenizer

@@ -15,11 +15,11 @@ from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

logger = init_logger(__name__)
@@ -97,6 +97,7 @@ class LLMEngine:
        self._verify_args()

        self._init_tokenizer()
        self.detokenizer = Detokenizer(self.tokenizer)
        self.seq_counter = Counter()

        self.model_executor = executor_class(model_config, cache_config,
@@ -153,7 +154,7 @@ class LLMEngine:
        raise RuntimeError("LLMEngine should not be pickled!")

    def get_tokenizer(self) -> "PreTrainedTokenizer":
        return self.tokenizer.get_lora_tokenizer()
        return self.tokenizer.get_lora_tokenizer(None)

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
@@ -370,13 +371,8 @@ class LLMEngine:
        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None:
            # We can pick any sequence for the prompt.
            seq = next(iter(seq_group.seqs_dict.values()))
            all_token_ids = seq.get_token_ids()
            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
                self._decode_logprobs(seq, seq_group.sampling_params,
                                      prompt_logprobs_for_token,
                                      all_token_ids[:i])
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
@@ -420,7 +416,8 @@ class LLMEngine:
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            self._decode_sequence(seq, seq_group.sampling_params)
            self.detokenizer.decode_sequence_inplace(seq,
                                                     seq_group.sampling_params)
            self._check_stop(seq, seq_group.sampling_params)

        # Non-beam search case
@@ -713,51 +710,6 @@ class LLMEngine:
            time_e2e_requests=time_e2e_requests,
        )

    def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
                         logprobs: Dict[int, Logprob],
                         all_input_ids: List[int]) -> None:
        if not logprobs:
            return
        for token_id, sample_logprob in logprobs.items():
            if (sample_logprob.decoded_token is None and token_id != -1):
                all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
                (_, new_text, prefix_offset,
                 read_offset) = detokenize_incrementally(
                     self.get_tokenizer_for_seq(seq),
                     all_input_ids=all_input_ids_with_logprob,
                     prev_tokens=seq.tokens,
                     prefix_offset=seq.prefix_offset,
                     read_offset=seq.read_offset,
                     skip_special_tokens=prms.skip_special_tokens,
                     spaces_between_special_tokens=prms.
                     spaces_between_special_tokens,
                 )
                sample_logprob.decoded_token = new_text

    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
        """Decodes the new token for a sequence."""
        all_input_ids = seq.get_token_ids()
        self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
                              all_input_ids)

        (new_tokens, new_output_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             self.get_tokenizer_for_seq(seq),
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )
        if seq.tokens is None:
            seq.tokens = new_tokens
        else:
            seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_output_text

    def _check_stop(self, seq: Sequence,
                    sampling_params: SamplingParams) -> None:
        """Stop the finished sequences."""
vllm/transformers_utils/detokenizer.py (new file, 155 lines)
@@ -0,0 +1,155 @@
from typing import List, Dict, Optional

from transformers import PreTrainedTokenizer

from vllm.sequence import Sequence, Logprob, SequenceGroup, SamplingParams
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                               convert_prompt_ids_to_tokens)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
    BaseTokenizerGroup)

# Used eg. for marking rejected tokens in spec decoding.
INVALID_TOKEN_ID = -1


class Detokenizer:
    """Provides methods to decode the output of a model into text."""

    def __init__(self, tokenizer_group: BaseTokenizerGroup):
        self.tokenizer_group = tokenizer_group

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        """Returns the HF tokenizer to use for a given sequence."""
        return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

    def decode_prompt_logprobs_inplace(
            self, seq_group: SequenceGroup,
            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
        """Decodes the logprobs for the prompt of a sequence group.

        Args:
            seq_group: The sequence group to decode.
            prompt_logprobs: The logprobs to decode.

        Returns:
            The prompt logprobs with the decoded tokens.
        """
        prms = seq_group.sampling_params
        # We can pick any sequence for the prompt.
        seq = next(iter(seq_group.seqs_dict.values()))
        # Only prompt, without the generated token.
        all_token_ids = seq.get_token_ids()
        prompt_token_ids = all_token_ids[:-1]
        tokenizer = self.get_tokenizer_for_seq(seq)
        prefix_offset = 0
        read_offset = 0
        next_iter_prefix_offset = 0
        next_iter_read_offset = 0
        next_iter_tokens = []
        prev_tokens = None

        for token_position, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):
            if not prompt_logprobs_for_token:
                continue
            for token_id, sample_logprob in prompt_logprobs_for_token.items():
                if (sample_logprob.decoded_token is None
                        and token_id != INVALID_TOKEN_ID):
                    prompt_token_ids_with_token = (
                        prompt_token_ids[:token_position] + [token_id])
                    (new_tokens, new_text, new_prefix_offset,
                     new_read_offset) = detokenize_incrementally(
                         tokenizer=tokenizer,
                         all_input_ids=prompt_token_ids_with_token,
                         prev_tokens=prev_tokens,
                         prefix_offset=prefix_offset,
                         read_offset=read_offset,
                         skip_special_tokens=prms.skip_special_tokens,
                         spaces_between_special_tokens=prms.
                         spaces_between_special_tokens,
                     )

                    sample_logprob.decoded_token = new_text

                    # Use the offsets & prev tokens corresponding to
                    # real tokens to ensure detokenization is consistent
                    # with the actual prompt.
                    if token_id == all_token_ids[token_position]:
                        next_iter_prefix_offset = new_prefix_offset
                        next_iter_read_offset = new_read_offset
                        next_iter_tokens = new_tokens

            # Advance to the next token position.
            prefix_offset = next_iter_prefix_offset
            read_offset = next_iter_read_offset
            if prev_tokens is None:
                prev_tokens = next_iter_tokens
            else:
                prev_tokens.extend(next_iter_tokens)

    def decode_sequence_inplace(self, seq: Sequence,
                                prms: SamplingParams) -> None:
        """Decodes the new token for a sequence. In-place operation.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.
        """
        all_input_ids = seq.get_token_ids()
        token_id_generated_this_iteration = all_input_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_tokens = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue

                if (sample_logprob.decoded_token is None
                        and token_id != INVALID_TOKEN_ID):
                    all_input_ids_with_logprob = previous_tokens + [token_id]
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                    )
                    sample_logprob.decoded_token = new_text

        if seq.tokens is None:
            seq.tokens = new_tokens
        else:
            seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text
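As a reference point, here is a hedged usage sketch of the new Detokenizer class (not part of the diff); the get_tokenizer_group arguments are copied from the detokenizer fixture in the test file above.

# Illustrative sketch only, not part of the diff. Construction mirrors the
# `detokenizer` fixture in the test file above.
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

tokenizer_group = get_tokenizer_group(
    None,  # first positional argument passed as None, matching the fixture
    tokenizer_id="facebook/opt-125m",
    enable_lora=False,
    max_num_seqs=100,
    max_input_length=None,
    tokenizer_mode="auto",
    trust_remote_code=False,
    revision=None,
)
detokenizer = Detokenizer(tokenizer_group)

# The engine calls decode_sequence_inplace(seq, sampling_params) once per
# generated token, and decode_prompt_logprobs_inplace(seq_group,
# prompt_logprobs) once for the prompt, as shown in the LLMEngine hunks above.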
@@ -158,6 +158,34 @@ def _convert_tokens_to_string_with_added_encoders(
    return "".join(sub_texts)


# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Converts the prompt ids to tokens and returns the tokens and offsets
    for incremental detokenization.

    Note that not all tokens are converted to strings. Only the tokens that
    are necessary for incremental detokenization are converted to strings.
    """
    # Offset a little more in case we have special tokens.
    prefix_offset = max(
        len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
    # We do not need to convert the whole prompt to tokens.
    new_tokens = tokenizer.convert_ids_to_tokens(
        prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
    prefix_offset = max(
        len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
    read_offset = len(new_tokens)
    return new_tokens, prefix_offset, read_offset


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
@@ -165,31 +193,53 @@ def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int = 0,
    read_offset: int = 0,
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Detokenizes the input ids incrementally and returns the new tokens
    and the new text.

    If `prev_tokens` is None, this function will convert the input ids to
    tokens and return the tokens and the new text. Otherwise, it will return the
    new tokens and the new text.

    This function will also return the new prefix offset and the new read
    offset to be used in the next iteration.

    The offsets are necessary to defeat cleanup algorithms in the decode which
    decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previous tokens. If None, this function will convert
            the input ids to tokens and return the tokens and the new text.
        prefix_offset: The prefix offset.
        read_offset: The read offset.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.
    """
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    if prev_tokens is None:
        new_tokens = tokenizer.convert_ids_to_tokens(
            all_input_ids, skip_special_tokens=skip_special_tokens)
        output_tokens = new_tokens
        # 5 is an arbitrary value that should work for all
        # tokenizers (bigger = more conservative).
        # Subtract 1 extra to account for the generated token.
        prefix_offset = max(len(output_tokens) - 6, 0)
        # If the first new token is a special token, we can't skip 1 extra token
        if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
            read_offset = max(len(output_tokens), 0)
        else:
            read_offset = max(len(output_tokens) - 1, 0)
    else:
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [new_token_id], skip_special_tokens=skip_special_tokens)
        output_tokens = prev_tokens + new_tokens
    is_first_iter = prev_tokens is None
    if is_first_iter:
        (prev_tokens, prefix_offset,
         read_offset) = convert_prompt_ids_to_tokens(
             tokenizer,
             all_input_ids[:-1],
             skip_special_tokens=skip_special_tokens)

    # Put new_token_id in a list so skip_special_tokens is respected
    new_tokens = tokenizer.convert_ids_to_tokens(
        [new_token_id], skip_special_tokens=skip_special_tokens)
    output_tokens = prev_tokens + new_tokens

    # If this is the first iteration, return all tokens.
    if is_first_iter:
        new_tokens = output_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
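To close, a hedged sketch of the updated calling convention for detokenize_incrementally (not part of the diff): prefix_offset and read_offset are now required arguments, and when prev_tokens is None the function seeds its state via convert_prompt_ids_to_tokens, recomputing the offsets that were passed in. The example strings reuse the test data above.

# Illustrative sketch only, not part of the diff.
from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
prompt_ids = tokenizer("Hello here, this is a simple",
                       add_special_tokens=False)["input_ids"]
new_token_id = tokenizer(" test", add_special_tokens=False)["input_ids"][0]

# First call: prev_tokens is None, so only the prompt tail is converted to
# token strings; the returned offsets must be fed back on subsequent calls.
new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
    tokenizer,
    all_input_ids=prompt_ids + [new_token_id],
    prev_tokens=None,
    prefix_offset=0,
    read_offset=0,
    skip_special_tokens=True,
)
print(new_text)  # text decoded for the newly generated token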