[Core] Improve detokenization performance for prefill (#3469)
Co-authored-by: MeloYang <meloyang05@gmail.com>
parent cf2f084d56
commit bfdb1ba5c3
@@ -1,13 +1,17 @@
 import pytest
 
 from transformers import AutoTokenizer
+from typing import List, Dict
 
+from vllm.sequence import Sequence, Logprob, SamplingParams, SequenceGroup
+from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer import Detokenizer
 
 TRUTH = [
-    "Hello here, this is a simple test",  # noqa: E501
-    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa: E501
-    "我很感谢你的热情"  # noqa: E501
+    "Hello here, this is a simple test",
+    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
+    "我很感谢你的热情"
 ]
 TOKENIZERS = [
     "facebook/opt-125m",
@@ -24,12 +28,12 @@ TOKENIZERS = [
 
 
 def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool):
+                            skip_special_tokens: bool, starting_index: int):
     decoded_text = ""
     offset = 0
     token_offset = 0
     prev_tokens = None
-    for i in range(len(all_input_ids)):
+    for i in range(starting_index, len(all_input_ids)):
         new_tokens, text, offset, token_offset = detokenize_incrementally(
             tokenizer,
             all_input_ids[:i + 1],
@@ -46,17 +50,152 @@ def _run_incremental_decode(tokenizer, all_input_ids,
 
 
 @pytest.mark.parametrize("truth", TRUTH)
+@pytest.mark.parametrize("with_prompt", [True, False])
 @pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", (True, False))
-def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
+def test_decode_streaming(tokenizer_id, truth, with_prompt,
+                          skip_special_tokens):
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
-    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+    if with_prompt:
+        truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
+        prompt_input_ids = truth_tokens[:len(truth) // 2]
+        generated_input_ids = truth_tokens[len(truth) // 2:]
+        all_input_ids = prompt_input_ids + generated_input_ids
+        starting_index = len(prompt_input_ids)
+        prompt = tokenizer.decode(prompt_input_ids,
+                                  skip_special_tokens=skip_special_tokens)
+        generated = truth[len(prompt):]
+    else:
+        generated = truth
+        starting_index = 0
+        all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
     if skip_special_tokens:
-        all_input_ids = ([tokenizer.bos_token_id]
-                         if tokenizer.bos_token_id is not None else
-                         []) + all_input_ids + [tokenizer.eos_token_id]
+        if tokenizer.bos_token_id is not None:
+            all_input_ids = [tokenizer.bos_token_id] + all_input_ids
+            starting_index += 1
+        all_input_ids = all_input_ids + [tokenizer.eos_token_id]
 
     decoded_text = _run_incremental_decode(
-        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=skip_special_tokens,
+        starting_index=starting_index)
 
-    assert decoded_text == truth
+    assert decoded_text == generated
+
+
+@pytest.fixture
+def detokenizer(tokenizer_name: str) -> Detokenizer:
+    init_kwargs = dict(
+        tokenizer_id=tokenizer_name,
+        enable_lora=False,
+        max_num_seqs=100,
+        max_input_length=None,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        revision=None,
+    )
+
+    tokenizer_group = get_tokenizer_group(
+        None,
+        **init_kwargs,
+    )
+
+    return Detokenizer(tokenizer_group)
+
+
+@pytest.fixture(name="complete_sequence_token_ids")
+def create_complete_sequence_token_ids(complete_sequence: str,
+                                       tokenizer_name: str) -> List[int]:
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
+    return complete_sequence_token_ids
+
+
+def create_sequence(prompt_token_ids=None):
+    prompt_token_ids = prompt_token_ids or [1]
+    return Sequence(
+        seq_id=0,
+        prompt="<s>",
+        prompt_token_ids=prompt_token_ids,
+        block_size=16,
+    )
+
+
+def create_dummy_logprobs(
+        complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
+    return [{
+        token_id: Logprob(logprob=0.0),
+        token_id + 1: Logprob(logprob=0.1)
+    } for token_id in complete_sequence_token_ids]
+
+
+@pytest.mark.parametrize("complete_sequence", TRUTH)
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("skip_special_tokens", [True, False])
+def test_decode_sequence_logprobs(complete_sequence: str,
+                                  complete_sequence_token_ids: List[int],
+                                  detokenizer: Detokenizer,
+                                  skip_special_tokens: bool):
+    """Verify Detokenizer decodes logprobs correctly."""
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+                                     logprobs=2)
+
+    # Run sequentially.
+    seq = create_sequence()
+    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
+    sequential_logprobs_text_chosen_token = []
+    sequential_logprobs_text_other_token = []
+    for new_token, logprobs in zip(complete_sequence_token_ids,
+                                   dummy_logprobs):
+        seq.append_token_id(new_token, logprobs)
+        detokenizer.decode_sequence_inplace(seq, sampling_params)
+        sequential_logprobs_text_chosen_token.append(
+            seq.output_logprobs[-1][new_token].decoded_token)
+        sequential_logprobs_text_other_token.append(
+            seq.output_logprobs[-1][new_token + 1].decoded_token)
+    sequential_result = seq.output_text
+
+    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
+    assert sequential_result != "".join(sequential_logprobs_text_other_token)
+
+    if skip_special_tokens:
+        # Text for logprobs for the chosen token should be the same as the
+        # generated text. Note that this will only be true if we skip
+        # special tokens.
+        assert sequential_result == complete_sequence
+
+
+@pytest.mark.parametrize("complete_sequence", TRUTH)
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("skip_special_tokens", [True])
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: List[int],
+                                detokenizer: Detokenizer,
+                                skip_special_tokens: bool):
+    """Verify Detokenizer decodes prompt logprobs correctly."""
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+                                     prompt_logprobs=1)
+
+    # Run sequentially.
+    seq = create_sequence(complete_sequence_token_ids)
+    seq_group = SequenceGroup(request_id="1",
+                              seqs=[seq],
+                              sampling_params=sampling_params,
+                              arrival_time=0.0)
+    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
+    detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
+    decoded_prompt_logprobs = dummy_logprobs
+
+    if skip_special_tokens:
+        # Text for logprobs for the chosen token should be the same as the
+        # prompt text. Note that this will only be true if we skip
+        # special tokens.
+        assert complete_sequence == "".join([
+            logprobs[token_id].decoded_token for token_id, logprobs in zip(
+                complete_sequence_token_ids, decoded_prompt_logprobs)
+        ])
+        assert complete_sequence != "".join([
+            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
+                complete_sequence_token_ids, decoded_prompt_logprobs)
+        ])
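A minimal sketch (not part of the diff) of the decode pattern the updated test drives: incremental detokenization starts at the prompt boundary rather than at token zero, so the prompt is never re-decoded token by token. The tokenizer id and the half/half prompt split below are arbitrary illustration choices; the call signature follows the diff above.

# Illustrative sketch only, not part of this commit.
from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
all_input_ids = tokenizer("Hello here, this is a simple test",
                          add_special_tokens=False)["input_ids"]
starting_index = len(all_input_ids) // 2  # pretend the first half is prompt

decoded_text = ""
prefix_offset = 0
read_offset = 0
prev_tokens = None
for i in range(starting_index, len(all_input_ids)):
    new_tokens, text, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer,
        all_input_ids[:i + 1],
        prev_tokens,
        prefix_offset=prefix_offset,
        read_offset=read_offset,
        skip_special_tokens=False)
    decoded_text += text
    if prev_tokens is None:
        prev_tokens = new_tokens
    else:
        prev_tokens.extend(new_tokens)

# decoded_text now holds roughly the text of the "generated" half only.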
@@ -1,5 +1,5 @@
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Tuple, Type, Union
 
 from transformers import PreTrainedTokenizer
 
@@ -15,11 +15,11 @@ from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
+from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
-from vllm.transformers_utils.tokenizer import detokenize_incrementally
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.utils import Counter
 
 logger = init_logger(__name__)
@@ -97,6 +97,7 @@ class LLMEngine:
         self._verify_args()
 
         self._init_tokenizer()
+        self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
 
         self.model_executor = executor_class(model_config, cache_config,
@@ -153,7 +154,7 @@
         raise RuntimeError("LLMEngine should not be pickled!")
 
     def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer()
+        return self.tokenizer.get_lora_tokenizer(None)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
@@ -370,13 +371,8 @@
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
         if prompt_logprobs is not None:
-            # We can pick any sequence for the prompt.
-            seq = next(iter(seq_group.seqs_dict.values()))
-            all_token_ids = seq.get_token_ids()
-            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
-                self._decode_logprobs(seq, seq_group.sampling_params,
-                                      prompt_logprobs_for_token,
-                                      all_token_ids[:i])
+            self.detokenizer.decode_prompt_logprobs_inplace(
+                seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
 
         # Process samples
@@ -420,7 +416,8 @@
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self._decode_sequence(seq, seq_group.sampling_params)
+            self.detokenizer.decode_sequence_inplace(seq,
+                                                     seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -713,51 +710,6 @@
             time_e2e_requests=time_e2e_requests,
         )
 
-    def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
-                         logprobs: Dict[int, Logprob],
-                         all_input_ids: List[int]) -> None:
-        if not logprobs:
-            return
-        for token_id, sample_logprob in logprobs.items():
-            if (sample_logprob.decoded_token is None and token_id != -1):
-                all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
-                (_, new_text, prefix_offset,
-                 read_offset) = detokenize_incrementally(
-                     self.get_tokenizer_for_seq(seq),
-                     all_input_ids=all_input_ids_with_logprob,
-                     prev_tokens=seq.tokens,
-                     prefix_offset=seq.prefix_offset,
-                     read_offset=seq.read_offset,
-                     skip_special_tokens=prms.skip_special_tokens,
-                     spaces_between_special_tokens=prms.
-                     spaces_between_special_tokens,
-                 )
-                sample_logprob.decoded_token = new_text
-
-    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
-        """Decodes the new token for a sequence."""
-        all_input_ids = seq.get_token_ids()
-        self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
-                              all_input_ids)
-
-        (new_tokens, new_output_text, prefix_offset,
-         read_offset) = detokenize_incrementally(
-             self.get_tokenizer_for_seq(seq),
-             all_input_ids=all_input_ids,
-             prev_tokens=seq.tokens,
-             prefix_offset=seq.prefix_offset,
-             read_offset=seq.read_offset,
-             skip_special_tokens=prms.skip_special_tokens,
-             spaces_between_special_tokens=prms.spaces_between_special_tokens,
-         )
-        if seq.tokens is None:
-            seq.tokens = new_tokens
-        else:
-            seq.tokens.extend(new_tokens)
-        seq.prefix_offset = prefix_offset
-        seq.read_offset = read_offset
-        seq.output_text += new_output_text
-
     def _check_stop(self, seq: Sequence,
                     sampling_params: SamplingParams) -> None:
         """Stop the finished sequences."""
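Roughly, the engine-side wiring after this change looks like the sketch below. It is simplified and not the real LLMEngine code; the class and method names are placeholders, and `seq_group`, `outputs`, and `child_seqs` stand for the objects visible in the hunks above. The removed `_decode_sequence`/`_decode_logprobs` helpers are replaced by two Detokenizer calls.

# Simplified sketch only; placeholder names, not the actual engine class.
from vllm.transformers_utils.detokenizer import Detokenizer


class EngineDetokenizationSketch:

    def __init__(self, tokenizer_group):
        # `tokenizer_group` plays the role of LLMEngine.tokenizer.
        self.detokenizer = Detokenizer(tokenizer_group)

    def process_prompt_logprobs(self, seq_group, outputs):
        # One call decodes logprobs for the whole prefill, replacing the old
        # per-position loop over _decode_logprobs.
        if outputs.prompt_logprobs is not None:
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, outputs.prompt_logprobs)
            seq_group.prompt_logprobs = outputs.prompt_logprobs

    def process_sampled_tokens(self, seq_group, child_seqs):
        # Each sequence's newly sampled token (and its logprobs) is decoded
        # in place, replacing the removed _decode_sequence helper.
        for seq, _ in child_seqs:
            self.detokenizer.decode_sequence_inplace(
                seq, seq_group.sampling_params)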
vllm/transformers_utils/detokenizer.py (new file, +155)
@@ -0,0 +1,155 @@
+from typing import List, Dict, Optional
+
+from transformers import PreTrainedTokenizer
+
+from vllm.sequence import Sequence, Logprob, SequenceGroup, SamplingParams
+from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
+                                               convert_prompt_ids_to_tokens)
+from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
+    BaseTokenizerGroup)
+
+# Used eg. for marking rejected tokens in spec decoding.
+INVALID_TOKEN_ID = -1
+
+
+class Detokenizer:
+    """Provides methods to decode the output of a model into text."""
+
+    def __init__(self, tokenizer_group: BaseTokenizerGroup):
+        self.tokenizer_group = tokenizer_group
+
+    def get_tokenizer_for_seq(self,
+                              sequence: Sequence) -> "PreTrainedTokenizer":
+        """Returns the HF tokenizer to use for a given sequence."""
+        return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
+
+    def decode_prompt_logprobs_inplace(
+            self, seq_group: SequenceGroup,
+            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
+        """Decodes the logprobs for the prompt of a sequence group.
+
+        Args:
+            seq_group: The sequence group to decode.
+            prompt_logprobs: The logprobs to decode.
+
+        Returns:
+            The prompt logprobs with the decoded tokens.
+        """
+        prms = seq_group.sampling_params
+        # We can pick any sequence for the prompt.
+        seq = next(iter(seq_group.seqs_dict.values()))
+        # Only prompt, without the generated token.
+        all_token_ids = seq.get_token_ids()
+        prompt_token_ids = all_token_ids[:-1]
+        tokenizer = self.get_tokenizer_for_seq(seq)
+        prefix_offset = 0
+        read_offset = 0
+        next_iter_prefix_offset = 0
+        next_iter_read_offset = 0
+        next_iter_tokens = []
+        prev_tokens = None
+
+        for token_position, prompt_logprobs_for_token in enumerate(
+                prompt_logprobs):
+            if not prompt_logprobs_for_token:
+                continue
+            for token_id, sample_logprob in prompt_logprobs_for_token.items():
+                if (sample_logprob.decoded_token is None
+                        and token_id != INVALID_TOKEN_ID):
+                    prompt_token_ids_with_token = (
+                        prompt_token_ids[:token_position] + [token_id])
+                    (new_tokens, new_text, new_prefix_offset,
+                     new_read_offset) = detokenize_incrementally(
+                         tokenizer=tokenizer,
+                         all_input_ids=prompt_token_ids_with_token,
+                         prev_tokens=prev_tokens,
+                         prefix_offset=prefix_offset,
+                         read_offset=read_offset,
+                         skip_special_tokens=prms.skip_special_tokens,
+                         spaces_between_special_tokens=prms.
+                         spaces_between_special_tokens,
+                     )
+
+                    sample_logprob.decoded_token = new_text
+
+                    # Use the offsets & prev tokens corresponding to
+                    # real tokens to ensure detokenization is consistent
+                    # actual with prompt.
+                    if token_id == all_token_ids[token_position]:
+                        next_iter_prefix_offset = new_prefix_offset
+                        next_iter_read_offset = new_read_offset
+                        next_iter_tokens = new_tokens
+
+            # Advance to the next token position.
+            prefix_offset = next_iter_prefix_offset
+            read_offset = next_iter_read_offset
+            if prev_tokens is None:
+                prev_tokens = next_iter_tokens
+            else:
+                prev_tokens.extend(next_iter_tokens)
+
+    def decode_sequence_inplace(self, seq: Sequence,
+                                prms: SamplingParams) -> None:
+        """Decodes the new token for a sequence. In-place operation.
+
+        Args:
+            seq: The sequence to decode.
+            prms: The sampling parameters used to generate the sequence.
+        """
+        all_input_ids = seq.get_token_ids()
+        token_id_generated_this_iteration = all_input_ids[-1]
+        tokenizer = self.get_tokenizer_for_seq(seq)
+
+        # Convert prompt token IDs to tokens if necessary.
+        # Do it here so that we don't have to repeat this
+        # computation for each logprob.
+        if seq.tokens is None:
+            (seq.tokens, seq.prefix_offset,
+             seq.read_offset) = convert_prompt_ids_to_tokens(
+                 tokenizer=tokenizer,
+                 prompt_ids=all_input_ids[:-1],
+                 skip_special_tokens=prms.skip_special_tokens,
+             )
+
+        (new_tokens, new_decoded_token_text, prefix_offset,
+         read_offset) = detokenize_incrementally(
+             tokenizer=tokenizer,
+             all_input_ids=all_input_ids,
+             prev_tokens=seq.tokens,
+             prefix_offset=seq.prefix_offset,
+             read_offset=seq.read_offset,
+             skip_special_tokens=prms.skip_special_tokens,
+             spaces_between_special_tokens=prms.spaces_between_special_tokens,
+         )
+
+        # Decode logprobs
+        logprobs = seq.output_logprobs[-1]
+        if logprobs:
+            previous_tokens = all_input_ids[:-1]
+            for token_id, sample_logprob in logprobs.items():
+                # If the token was generated this iteration,
+                # use the provided text.
+                if token_id == token_id_generated_this_iteration:
+                    sample_logprob.decoded_token = new_decoded_token_text
+                    continue
+
+                if (sample_logprob.decoded_token is None
+                        and token_id != INVALID_TOKEN_ID):
+                    all_input_ids_with_logprob = previous_tokens + [token_id]
+                    (_, new_text, _, _) = detokenize_incrementally(
+                        tokenizer=tokenizer,
+                        all_input_ids=all_input_ids_with_logprob,
+                        prev_tokens=seq.tokens,
+                        prefix_offset=seq.prefix_offset,
+                        read_offset=seq.read_offset,
+                        skip_special_tokens=prms.skip_special_tokens,
+                        spaces_between_special_tokens=prms.
+                        spaces_between_special_tokens,
+                    )
+                    sample_logprob.decoded_token = new_text
+
+        if seq.tokens is None:
+            seq.tokens = new_tokens
+        else:
+            seq.tokens.extend(new_tokens)
+        seq.prefix_offset = prefix_offset
+        seq.read_offset = read_offset
+        seq.output_text += new_decoded_token_text
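The Detokenizer can also be constructed standalone from a tokenizer group, exactly as the `detokenizer` fixture in the test file above does; a condensed sketch:

# Condensed from the test fixture above; keyword arguments follow that fixture.
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

tokenizer_group = get_tokenizer_group(
    None,  # same first positional argument as in the fixture
    tokenizer_id="facebook/opt-125m",
    enable_lora=False,
    max_num_seqs=100,
    max_input_length=None,
    tokenizer_mode="auto",
    trust_remote_code=False,
    revision=None,
)
detokenizer = Detokenizer(tokenizer_group)

As the tests show, `decode_sequence_inplace` and `decode_prompt_logprobs_inplace` mutate the `Sequence` and `Logprob` objects in place rather than returning text.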
@@ -158,6 +158,34 @@ def _convert_tokens_to_string_with_added_encoders(
     return "".join(sub_texts)
 
 
+# 5 is an arbitrary value that should work for all
+# tokenizers (bigger = more conservative).
+INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+
+def convert_prompt_ids_to_tokens(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt_ids: List[int],
+    skip_special_tokens: bool = False,
+) -> Tuple[List[str], int, int]:
+    """Converts the prompt ids to tokens and returns the tokens and offsets
+    for incremental detokenization.
+
+    Note that not all tokens are converted to strings. Only the tokens that
+    are necessary for incremental detokenization are converted to strings.
+    """
+    # Offset a little more in case we have special tokens.
+    prefix_offset = max(
+        len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
+    # We do not need to convert the whole prompt to tokens.
+    new_tokens = tokenizer.convert_ids_to_tokens(
+        prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
+    prefix_offset = max(
+        len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
+    read_offset = len(new_tokens)
+    return new_tokens, prefix_offset, read_offset
+
+
 # Based on
 # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
 # under Apache 2.0 license
@@ -165,31 +193,53 @@ def detokenize_incrementally(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     all_input_ids: List[int],
     prev_tokens: Optional[List[str]],
-    prefix_offset: int = 0,
-    read_offset: int = 0,
+    prefix_offset: int,
+    read_offset: int,
     skip_special_tokens: bool = False,
     spaces_between_special_tokens: bool = True,
 ) -> Tuple[List[str], str, int, int]:
+    """Detokenizes the input ids incrementally and returns the new tokens
+    and the new text.
+
+    If `prev_tokens` is None, this function will convert the input ids to
+    tokens and return the tokens and the new text. Otherwise, it will return
+    the new tokens and the new text.
+
+    This function will also return the new prefix offset and the new read
+    offset to be used in the next iteration.
+
+    The offsets are necessary to defeat cleanup algorithms in the decode which
+    decide to add a space or not depending on the surrounding ids.
+
+    Args:
+        tokenizer: The tokenizer to use.
+        all_input_ids: The input ids. The last id is the new token id.
+        prev_tokens: The previous tokens. If None, this function will convert
+            the input ids to tokens and return the tokens and the new text.
+        prefix_offset: The prefix offset.
+        read_offset: The read offset.
+        skip_special_tokens: Whether to skip special tokens.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens.
+    """
     new_token_id = all_input_ids[-1]
     # This is the first iteration for this sequence
-    if prev_tokens is None:
-        new_tokens = tokenizer.convert_ids_to_tokens(
-            all_input_ids, skip_special_tokens=skip_special_tokens)
-        output_tokens = new_tokens
-        # 5 is an arbitrary value that should work for all
-        # tokenizers (bigger = more conservative).
-        # Subtract 1 extra to account for the generated token.
-        prefix_offset = max(len(output_tokens) - 6, 0)
-        # If the first new token is a special token, we can't skip 1 extra token
-        if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
-            read_offset = max(len(output_tokens), 0)
-        else:
-            read_offset = max(len(output_tokens) - 1, 0)
-    else:
-        # Put new_token_id in a list so skip_special_tokens is respected
-        new_tokens = tokenizer.convert_ids_to_tokens(
-            [new_token_id], skip_special_tokens=skip_special_tokens)
-        output_tokens = prev_tokens + new_tokens
+    is_first_iter = prev_tokens is None
+    if is_first_iter:
+        (prev_tokens, prefix_offset,
+         read_offset) = convert_prompt_ids_to_tokens(
+             tokenizer,
+             all_input_ids[:-1],
+             skip_special_tokens=skip_special_tokens)
+
+    # Put new_token_id in a list so skip_special_tokens is respected
+    new_tokens = tokenizer.convert_ids_to_tokens(
+        [new_token_id], skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_tokens + new_tokens
+
+    # If this is the first iteration, return all tokens.
+    if is_first_iter:
+        new_tokens = output_tokens
 
     # The prefix text is necessary only to defeat cleanup algorithms in
     # the decode which decide to add a space or not depending on the
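To see why prefill detokenization gets cheaper, note that `convert_prompt_ids_to_tokens` converts only a short suffix of the prompt before incremental decoding starts. A small illustration, assuming the constant added above (the helper name below is only for the illustration):

# Illustration only, not part of the commit.
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5  # mirrors the constant above


def prompt_suffix_length(num_prompt_ids: int) -> int:
    # Mirrors the offset computation in convert_prompt_ids_to_tokens.
    prefix_offset = max(
        num_prompt_ids - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
    return num_prompt_ids - prefix_offset


print(prompt_suffix_length(10))    # 7
print(prompt_suffix_length(1000))  # still 7: work no longer grows with prompt

Combined with the change to `detokenize_incrementally`, which now converts only the newly generated id and delegates the prompt window to `convert_prompt_ids_to_tokens` on the first call, the first decode step after prefill no longer scales with prompt length.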