[Core] Reduce unnecessary compute when logprobs=None (#6532)

Author: Peng Guanwen, 2024-07-30 00:47:31 +08:00 (committed by GitHub)
parent 766435e660
commit db9e5708a9
4 changed files with 133 additions and 78 deletions
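
The change is easiest to see from the caller's side: when `SamplingParams.logprobs` is left as None (the default) and beam search is off, the sampler now skips the logprob gather and top-k work entirely, and both `logprobs` and `cumulative_logprob` on the completion come back as None. A minimal usage sketch (assuming a local vLLM install; the model name simply mirrors the test file below):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")

    # logprobs=None (the default) is the fast path this commit introduces:
    # no top-k, no gather, no GPU->CPU copy of logprob tensors.
    params = SamplingParams(max_tokens=5, temperature=0.0, logprobs=None)
    outputs = llm.generate(["Hello world"], params)

    completion = outputs[0].outputs[0]
    assert completion.logprobs is None
    assert completion.cumulative_logprob is None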

View File

@@ -14,7 +14,7 @@ MODELS = ["facebook/opt-125m"]
 @pytest.mark.parametrize("dtype",
                          ["float"])  # needed for comparing logprobs with HF
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
-@pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
+@pytest.mark.parametrize("num_top_logprobs", [0, 6])  # 32000 == vocab_size
 @pytest.mark.parametrize("detokenize", [True, False])
 def test_get_prompt_logprobs(
     hf_runner,
@@ -63,7 +63,10 @@ def test_get_prompt_logprobs(
         assert result.outputs[0].logprobs is not None
         assert len(result.outputs[0].logprobs) == max_tokens
         for logprobs in result.outputs[0].logprobs:
-            assert len(logprobs) == num_top_logprobs
+            # If the output token is not included in the top X
+            # logprob, it can return 1 more data
+            assert (len(logprobs) == num_top_logprobs
+                    or len(logprobs) == num_top_logprobs + 1)
         output_text = result.outputs[0].text
         output_string_from_most_likely_tokens_lst: List[str] = []
         for top_logprobs in result.outputs[0].logprobs:
@@ -135,3 +138,35 @@ def test_max_logprobs():
     bad_sampling_params = SamplingParams(logprobs=2)
     with pytest.raises(ValueError):
         runner.generate(["Hello world"], sampling_params=bad_sampling_params)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
+@pytest.mark.parametrize("detokenize", [True, False])
+def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
+                       detokenize: bool, example_prompts):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+    max_tokens = 5
+
+    with vllm_runner(
+            model,
+            enable_chunked_prefill=enable_chunked_prefill,
+            max_num_batched_tokens=max_num_batched_tokens,
+            max_num_seqs=max_num_seqs,
+    ) as vllm_model:
+        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
+                                                       logprobs=None,
+                                                       temperature=0.0,
+                                                       detokenize=detokenize)
+        results_logprobs_none = vllm_model.model.generate(
+            example_prompts, sampling_params=sampling_params_logprobs_none)
+
+        for i in range(len(results_logprobs_none)):
+            assert results_logprobs_none[i].outputs[0].logprobs is None
+            assert results_logprobs_none[i].outputs[0].cumulative_logprob is None

View File

@@ -1,5 +1,6 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+from math import inf
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -774,8 +775,11 @@ def _get_logprobs(
     # The next token ids to get the logprob value from.
     next_token_ids: List[int] = []
     # The largest requested number of logprobs. We find logprobs as many as the
-    # largest num logprobs in this API.
-    largest_num_logprobs = 1
+    # largest num logprobs in this API. If every logprobs is None, it will be
+    # set to -1.
+    largest_num_logprobs = -1
+    # If beam search is enabled.
+    use_beam_search = False

     # Select indices to compute logprob from, ranks of token ids, and the top
     # k token ids from logprobs.
@@ -808,6 +812,8 @@ def _get_logprobs(
             largest_num_logprobs = max(largest_num_logprobs,
                                        sampling_params.logprobs)

+        use_beam_search = use_beam_search or sampling_params.use_beam_search
+
     assert len(next_token_ids) == len(query_indices)

     if len(query_indices) == 0:
@@ -815,8 +821,15 @@ def _get_logprobs(
         empty_prompt_logprob: Optional[PromptLogprobs] = None
         return [empty_prompt_logprob], [empty_sampled_logprob]

-    query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
-    next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
+    selected_logprobs, ranks = None, None
+    top_logprobs, top_token_ids = None, None

-    # (num_selected_query_tokens, num_logprobs). Note that query_indices can
-    # contain duplicates if beam search is enabled.
+    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
+    # skip the whole logprob calculation.
+    if largest_num_logprobs >= 0 or use_beam_search:
+        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
+        next_token_ids_gpu = torch.tensor(next_token_ids,
+                                          device=logprobs.device)
+
+        # (num_selected_query_tokens, num_logprobs). Note that query_indices can
+        # contain duplicates if beam search is enabled.
@@ -830,20 +843,18 @@ def _get_logprobs(
-    )
-    assert selected_logprobs.shape[0] == ranks.shape[0]
-
-    # We need to compute top k only if there exists logprobs > 0.
-    if largest_num_logprobs > 0:
-        # Logprobs of topk tokens for a batch of sequence groups.
-        # (num_query_tokens_across_batch).
-        top_logprobs, top_token_ids = torch.topk(logprobs,
-                                                 largest_num_logprobs,
-                                                 dim=-1)
-    else:
-        top_logprobs, top_token_ids = None, None
-
-    selected_logprobs = selected_logprobs.to('cpu')
-    ranks = ranks.to('cpu')
-    if top_logprobs is not None and top_token_ids is not None:
-        top_logprobs = top_logprobs.to('cpu')
-        top_token_ids = top_token_ids.to('cpu')
+        )
+        assert selected_logprobs.shape[0] == ranks.shape[0]
+
+        # Logprobs of topk tokens for a batch of sequence groups.
+        # (num_query_tokens_across_batch).
+        if largest_num_logprobs > 0:
+            top_logprobs, top_token_ids = torch.topk(logprobs,
+                                                     largest_num_logprobs,
+                                                     dim=-1)
+            top_logprobs = top_logprobs.to('cpu')
+            top_token_ids = top_token_ids.to('cpu')
+
+        selected_logprobs = selected_logprobs.to('cpu')
+        ranks = ranks.to('cpu')
+
     # Find prompt/sample logprobs.
     prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
@@ -940,12 +951,18 @@ def _get_sampled_logprob_if_needed(
 ):
     """Compute the sample logprob if needed."""
     seq_ids = seq_group.seq_ids
-    num_logprobs = seq_group.sampling_params.logprobs or 0
+    num_logprobs = seq_group.sampling_params.logprobs
+    use_beam_search = seq_group.sampling_params.use_beam_search
     sampled_logprobs: SampleLogprobs = []
     next_token_ids, parent_seq_ids = sample_result

     if seq_group.do_sample:
         assert len(next_token_ids) > 0
-        # Pre-select items from tensor. tolist() is faster than repetitive
-        # `.item()` calls.
-        selected_logprob_items = selected_logprobs[
+        if num_logprobs is None and not use_beam_search:
+            for next_token_id in next_token_ids:
+                # Use a dummy logprob
+                sampled_logprobs.append({next_token_id: Logprob(inf)})
+        else:
+            # Pre-select items from tensor. tolist() is faster than repetitive
+            # `.item()` calls.
+            selected_logprob_items = selected_logprobs[
@@ -953,25 +970,26 @@ def _get_sampled_logprob_if_needed(
-            len(next_token_ids)].tolist()
-        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
-                           len(next_token_ids)].tolist()
-
-        for idx, (next_token_id,
-                  parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
-            # Get the logprob of a sampled token.
-            sampled_logprobs_dict = {
-                next_token_id: (selected_logprob_items[idx], rank_items[idx])
-            }
-            # Get top K logprobs.
-            if num_logprobs > 0:
-                top_ids = top_token_ids[top_logprob_idx +
-                                        parent_id, :num_logprobs].tolist()
-                top_probs = top_logprobs[top_logprob_idx +
-                                         parent_id, :num_logprobs].tolist()
-                # Top K is already sorted by rank, so we can use 1 ~
-                # num_logprobs + 1 for rank.
-                top_ranks = range(1, num_logprobs + 1)
-                sampled_logprobs_dict.update({
-                    top_id: (top_prob, rank)
-                    for top_id, top_prob, rank in zip(top_ids, top_probs,
-                                                      top_ranks)
-                })
-            sampled_logprobs.append({
+                len(next_token_ids)].tolist()
+            rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
+                               len(next_token_ids)].tolist()
+
+            for idx, (next_token_id, parent_id) in enumerate(
+                    zip(next_token_ids, parent_seq_ids)):
+                # Get the logprob of a sampled token.
+                sampled_logprobs_dict = {
+                    next_token_id:
+                    (selected_logprob_items[idx], rank_items[idx])
+                }
+                if num_logprobs is not None and num_logprobs > 0:
+                    # Get top K logprobs.
+                    top_ids = top_token_ids[top_logprob_idx +
+                                            parent_id, :num_logprobs].tolist()
+                    top_probs = top_logprobs[
+                        top_logprob_idx + parent_id, :num_logprobs].tolist()
+                    # Top K is already sorted by rank, so we can use 1 ~
+                    # num_logprobs + 1 for rank.
+                    top_ranks = range(1, num_logprobs + 1)
+                    sampled_logprobs_dict.update({
+                        top_id: (top_prob, rank)
+                        for top_id, top_prob, rank in zip(
+                            top_ids, top_probs, top_ranks)
+                    })
+
+                sampled_logprobs.append({
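
The core of the sampler change is the early-out around the top-k/gather work: nothing is computed from the `logprobs` tensor unless at least one sequence group asked for logprobs or beam search needs the ranks. The standalone sketch below illustrates that gating pattern; `maybe_topk_logprobs` and `requested_num_logprobs` are made-up names for illustration, not vLLM API, and the helper merely mirrors the -1 sentinel used in the patched `_get_logprobs`:

    from typing import List, Optional, Tuple

    import torch


    def maybe_topk_logprobs(
        logprobs: torch.Tensor,  # (num_tokens, vocab_size), log-softmaxed
        requested_num_logprobs: List[Optional[int]],  # per-request `logprobs`
        use_beam_search: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return (top_logprobs, top_token_ids), or (None, None) if nobody asked."""
        # -1 plays the same sentinel role as largest_num_logprobs in the patch:
        # it means every request had logprobs=None.
        largest = max((n for n in requested_num_logprobs if n is not None),
                      default=-1)
        if largest < 0 and not use_beam_search:
            # Fast path added by this commit: skip top-k and GPU->CPU copies.
            return None, None
        if largest > 0:
            top_logprobs, top_token_ids = torch.topk(logprobs, largest, dim=-1)
            return top_logprobs.cpu(), top_token_ids.cpu()
        # logprobs=0 (sampled-token logprob only) or beam search: ranks are
        # still gathered elsewhere, but no top-k tensor is needed.
        return None, None


    # Two requests, neither wants logprobs -> the top-k work is skipped.
    dummy = torch.log_softmax(torch.randn(2, 32), dim=-1)
    assert maybe_topk_logprobs(dummy, [None, None]) == (None, None)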

View File

@@ -29,7 +29,7 @@ class CompletionOutput:
     index: int
     text: str
     token_ids: Tuple[int, ...]
-    cumulative_logprob: float
+    cumulative_logprob: Optional[float]
     logprobs: Optional[SampleLogprobs]
     finish_reason: Optional[str] = None
     stop_reason: Union[int, str, None] = None
@@ -124,10 +124,11 @@ class RequestOutput:
         include_logprobs = seq_group.sampling_params.logprobs is not None
         text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(seqs.index(seq),
-                             seq.get_output_text_to_return(text_buffer_length),
-                             seq.get_output_token_ids(),
-                             seq.get_cumulative_logprob(),
-                             seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status),
-                             seq.stop_reason) for seq in top_n_seqs
+            CompletionOutput(
+                seqs.index(seq),
+                seq.get_output_text_to_return(text_buffer_length),
+                seq.get_output_token_ids(),
+                seq.get_cumulative_logprob() if include_logprobs else None,
+                seq.output_logprobs if include_logprobs else None,
+                SequenceStatus.get_finished_reason(seq.status),
+                seq.stop_reason) for seq in top_n_seqs
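
Because `CompletionOutput.cumulative_logprob` is now Optional, downstream code that reads it should guard for None rather than assume a float. A tiny, hypothetical post-processing helper showing the check (the function name is illustrative, not part of vLLM):

    from typing import Optional


    def format_score(cumulative_logprob: Optional[float]) -> str:
        # None means the request was made with logprobs=None, so no score exists.
        if cumulative_logprob is None:
            return "score: n/a (logprobs not requested)"
        return f"score: {cumulative_logprob:.4f}"


    print(format_score(None))   # score: n/a (logprobs not requested)
    print(format_score(-3.25))  # score: -3.2500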

View File

@@ -92,11 +92,12 @@ class SamplingParams:
         min_tokens: Minimum number of tokens to generate per output sequence
             before EOS or stop_token_ids can be generated
         logprobs: Number of log probabilities to return per output token.
-            Note that the implementation follows the OpenAI API: The return
-            result includes the log probabilities on the `logprobs` most likely
-            tokens, as well the chosen tokens. The API will always return the
-            log probability of the sampled token, so there may be up to
-            `logprobs+1` elements in the response.
+            When set to None, no probability is returned. If set to a non-None
+            value, the result includes the log probabilities of the specified
+            number of most likely tokens, as well as the chosen tokens.
+            Note that the implementation follows the OpenAI API: The API will
+            always return the log probability of the sampled token, so there
+            may be up to `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
@@ -168,8 +169,8 @@
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.min_tokens = min_tokens
-        self.logprobs = logprobs
-        self.prompt_logprobs = prompt_logprobs
+        self.logprobs = 1 if logprobs is True else logprobs
+        self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
         # NOTE: This parameter is only exposed at the engine level for now.
         # It is not exposed in the OpenAI API server, as the OpenAI API does
         # not support returning only a list of token IDs.
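
Two of the behaviors documented above are worth pinning down with concrete values: `logprobs=k` yields the top-k candidates plus the sampled token (so each per-token dict has k or k+1 entries), and the constructor now coerces a boolean True to 1, so `logprobs=True` behaves like `logprobs=1`. A short sketch, assuming only the public `SamplingParams` constructor:

    from vllm import SamplingParams

    # Boolean coercion added in this commit: True becomes 1, None passes through.
    assert SamplingParams(logprobs=True).logprobs == 1
    assert SamplingParams(prompt_logprobs=True).prompt_logprobs == 1
    assert SamplingParams(logprobs=None).logprobs is None

    # With logprobs=2, each entry of CompletionOutput.logprobs is a dict of
    # token_id -> Logprob covering the top-2 candidates plus the sampled token,
    # i.e. 2 or 3 entries per generated token (cf. the test assertion above).
    params = SamplingParams(max_tokens=5, temperature=0.0, logprobs=2)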