[Core] Reduce unnecessary compute when logprobs=None (#6532)

parent 766435e660
commit db9e5708a9
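
For context, here is a minimal usage sketch of the path this commit optimizes, adapted from the new test added below. It assumes the standard offline `vllm.LLM` entry point; the model name is illustrative. With `logprobs=None`, the sampler now skips the logprob gather/top-k work entirely and the corresponding output fields come back as `None`.

# Sketch only: exercises the logprobs=None fast path introduced by this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
params = SamplingParams(max_tokens=5, temperature=0.0, logprobs=None)

for output in llm.generate(["Hello world"], sampling_params=params):
    completion = output.outputs[0]
    # No per-token logprobs are computed or returned on this path.
    assert completion.logprobs is None
    assert completion.cumulative_logprob is None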
tests/samplers/test_logprobs.py

@@ -14,7 +14,7 @@ MODELS = ["facebook/opt-125m"]
 @pytest.mark.parametrize("dtype",
                          ["float"])  # needed for comparing logprobs with HF
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
-@pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
+@pytest.mark.parametrize("num_top_logprobs", [0, 6])  # 32000 == vocab_size
 @pytest.mark.parametrize("detokenize", [True, False])
 def test_get_prompt_logprobs(
     hf_runner,
@@ -63,7 +63,10 @@ def test_get_prompt_logprobs(
         assert result.outputs[0].logprobs is not None
         assert len(result.outputs[0].logprobs) == max_tokens
         for logprobs in result.outputs[0].logprobs:
-            assert len(logprobs) == num_top_logprobs
+            # If the output token is not included in the top X
+            # logprob, it can return 1 more data
+            assert (len(logprobs) == num_top_logprobs
+                    or len(logprobs) == num_top_logprobs + 1)
         output_text = result.outputs[0].text
         output_string_from_most_likely_tokens_lst: List[str] = []
         for top_logprobs in result.outputs[0].logprobs:
@@ -135,3 +138,35 @@ def test_max_logprobs():
     bad_sampling_params = SamplingParams(logprobs=2)
     with pytest.raises(ValueError):
         runner.generate(["Hello world"], sampling_params=bad_sampling_params)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
+@pytest.mark.parametrize("detokenize", [True, False])
+def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
+                       detokenize: bool, example_prompts):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+    max_tokens = 5
+
+    with vllm_runner(
+            model,
+            enable_chunked_prefill=enable_chunked_prefill,
+            max_num_batched_tokens=max_num_batched_tokens,
+            max_num_seqs=max_num_seqs,
+    ) as vllm_model:
+        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
+                                                       logprobs=None,
+                                                       temperature=0.0,
+                                                       detokenize=detokenize)
+        results_logprobs_none = vllm_model.model.generate(
+            example_prompts, sampling_params=sampling_params_logprobs_none)
+
+    for i in range(len(results_logprobs_none)):
+        assert results_logprobs_none[i].outputs[0].logprobs is None
+        assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
vllm/model_executor/layers/sampler.py

@@ -1,5 +1,6 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+from math import inf
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -774,8 +775,11 @@ def _get_logprobs(
     # The next token ids to get the logprob value from.
     next_token_ids: List[int] = []
     # The largest requested number of logprobs. We find logprobs as many as the
-    # largest num logprobs in this API.
-    largest_num_logprobs = 1
+    # largest num logprobs in this API. If every logprobs is None, it will be
+    # set to -1.
+    largest_num_logprobs = -1
+    # If beam search is enabled.
+    use_beam_search = False
 
     # Select indices to compute logprob from, ranks of token ids, and the top
     # k token ids from logprobs.
@@ -808,6 +812,8 @@ def _get_logprobs(
             largest_num_logprobs = max(largest_num_logprobs,
                                        sampling_params.logprobs)
 
+        use_beam_search = use_beam_search or sampling_params.use_beam_search
+
     assert len(next_token_ids) == len(query_indices)
 
     if len(query_indices) == 0:
@@ -815,35 +821,40 @@ def _get_logprobs(
         empty_prompt_logprob: Optional[PromptLogprobs] = None
         return [empty_prompt_logprob], [empty_sampled_logprob]
 
-    query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
-    next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
-
-    # (num_selected_query_tokens, num_logprobs). Note that query_indices can
-    # contain duplicates if beam search is enabled.
-    selected_logprobs = logprobs[[
-        query_indices_gpu,
-        next_token_ids_gpu,
-    ]]
-    ranks = _get_ranks(
-        logprobs[query_indices_gpu],
-        next_token_ids_gpu,
-    )
-    assert selected_logprobs.shape[0] == ranks.shape[0]
-
-    # Logprobs of topk tokens for a batch of sequence groups.
-    # (num_query_tokens_across_batch).
-    if largest_num_logprobs > 0:
-        top_logprobs, top_token_ids = torch.topk(logprobs,
-                                                 largest_num_logprobs,
-                                                 dim=-1)
-    else:
-        top_logprobs, top_token_ids = None, None
-
-    selected_logprobs = selected_logprobs.to('cpu')
-    ranks = ranks.to('cpu')
-    if top_logprobs is not None and top_token_ids is not None:
-        top_logprobs = top_logprobs.to('cpu')
-        top_token_ids = top_token_ids.to('cpu')
+    selected_logprobs, ranks = None, None
+    top_logprobs, top_token_ids = None, None
+
+    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
+    # skip the whole logprob calculation.
+    if largest_num_logprobs >= 0 or use_beam_search:
+        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
+        next_token_ids_gpu = torch.tensor(next_token_ids,
+                                          device=logprobs.device)
+
+        # (num_selected_query_tokens, num_logprobs). Note that query_indices can
+        # contain duplicates if beam search is enabled.
+        selected_logprobs = logprobs[[
+            query_indices_gpu,
+            next_token_ids_gpu,
+        ]]
+        ranks = _get_ranks(
+            logprobs[query_indices_gpu],
+            next_token_ids_gpu,
+        )
+        assert selected_logprobs.shape[0] == ranks.shape[0]
+
+        # We need to compute top k only if there exists logprobs > 0.
+        if largest_num_logprobs > 0:
+            # Logprobs of topk tokens for a batch of sequence groups.
+            # (num_query_tokens_across_batch).
+            top_logprobs, top_token_ids = torch.topk(logprobs,
+                                                     largest_num_logprobs,
+                                                     dim=-1)
+            top_logprobs = top_logprobs.to('cpu')
+            top_token_ids = top_token_ids.to('cpu')
+
+        selected_logprobs = selected_logprobs.to('cpu')
+        ranks = ranks.to('cpu')
 
     # Find prompt/sample logprobs.
     prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
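
The gist of the hunk above, as a standalone sketch with illustrative names (not vLLM's internals): the gather, ranking, top-k, and GPU-to-CPU copies only run when at least one sequence group requested logprobs (or beam search is in use); otherwise the whole calculation is skipped.

import torch
from typing import Optional


def gather_logprobs_if_requested(
        logprobs: torch.Tensor,      # (num_query_tokens, vocab_size)
        token_ids: torch.Tensor,     # (num_query_tokens,), dtype long
        largest_num_logprobs: int,   # -1 means no sequence asked for logprobs
        use_beam_search: bool = False) -> Optional[torch.Tensor]:
    """Illustrative gating: do the logprob work only when someone needs it."""
    if largest_num_logprobs < 0 and not use_beam_search:
        # Skip the GPU gather, the top-k, and the GPU->CPU copy entirely.
        return None
    rows = torch.arange(logprobs.shape[0], device=logprobs.device)
    selected = logprobs[rows, token_ids]
    if largest_num_logprobs > 0:
        # Top-k is only needed when a positive number of logprobs was asked for.
        _ = torch.topk(logprobs, largest_num_logprobs, dim=-1)
    return selected.to("cpu")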
@@ -940,45 +951,52 @@ def _get_sampled_logprob_if_needed(
 ):
     """Compute the sample logprob if needed."""
     seq_ids = seq_group.seq_ids
-    num_logprobs = seq_group.sampling_params.logprobs or 0
+    num_logprobs = seq_group.sampling_params.logprobs
+    use_beam_search = seq_group.sampling_params.use_beam_search
     sampled_logprobs: SampleLogprobs = []
     next_token_ids, parent_seq_ids = sample_result
 
     if seq_group.do_sample:
         assert len(next_token_ids) > 0
-        # Pre-select items from tensor. tolist() is faster than repetitive
-        # `.item()` calls.
-        selected_logprob_items = selected_logprobs[
-            selected_logprobs_idx:selected_logprobs_idx +
-            len(next_token_ids)].tolist()
-        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
-                           len(next_token_ids)].tolist()
-        for idx, (next_token_id,
-                  parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
-            # Get the logprob of a sampled token.
-            sampled_logprobs_dict = {
-                next_token_id: (selected_logprob_items[idx], rank_items[idx])
-            }
-            # Get top K logprobs.
-            if num_logprobs > 0:
-                top_ids = top_token_ids[top_logprob_idx +
-                                        parent_id, :num_logprobs].tolist()
-                top_probs = top_logprobs[top_logprob_idx +
-                                         parent_id, :num_logprobs].tolist()
-                # Top K is already sorted by rank, so we can use 1 ~
-                # num_logprobs + 1 for rank.
-                top_ranks = range(1, num_logprobs + 1)
-                sampled_logprobs_dict.update({
-                    top_id: (top_prob, rank)
-                    for top_id, top_prob, rank in zip(top_ids, top_probs,
-                                                      top_ranks)
-                })
-
-            sampled_logprobs.append({
-                token_id: Logprob(*logprob_and_rank)
-                for token_id, logprob_and_rank in
-                sampled_logprobs_dict.items()
-            })
+        if num_logprobs is None and not use_beam_search:
+            for next_token_id in next_token_ids:
+                # Use a dummy logprob
+                sampled_logprobs.append({next_token_id: Logprob(inf)})
+        else:
+            # Pre-select items from tensor. tolist() is faster than repetitive
+            # `.item()` calls.
+            selected_logprob_items = selected_logprobs[
+                selected_logprobs_idx:selected_logprobs_idx +
+                len(next_token_ids)].tolist()
+            rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
+                               len(next_token_ids)].tolist()
+            for idx, (next_token_id, parent_id) in enumerate(
+                    zip(next_token_ids, parent_seq_ids)):
+                # Get the logprob of a sampled token.
+                sampled_logprobs_dict = {
+                    next_token_id:
+                    (selected_logprob_items[idx], rank_items[idx])
+                }
+                if num_logprobs is not None and num_logprobs > 0:
+                    # Get top K logprobs.
+                    top_ids = top_token_ids[top_logprob_idx +
+                                            parent_id, :num_logprobs].tolist()
+                    top_probs = top_logprobs[
+                        top_logprob_idx + parent_id, :num_logprobs].tolist()
+                    # Top K is already sorted by rank, so we can use 1 ~
+                    # num_logprobs + 1 for rank.
+                    top_ranks = range(1, num_logprobs + 1)
+                    sampled_logprobs_dict.update({
+                        top_id: (top_prob, rank)
+                        for top_id, top_prob, rank in zip(
+                            top_ids, top_probs, top_ranks)
+                    })
+
+                sampled_logprobs.append({
+                    token_id: Logprob(*logprob_and_rank)
+                    for token_id, logprob_and_rank in
+                    sampled_logprobs_dict.items()
+                })
 
         # NOTE: This part of code is not intuitive. `selected_logprobs` include
         # logprobs for the current step, which has len(next_token_ids) tokens
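
To keep return shapes stable, the `logprobs=None` branch above attaches a placeholder `Logprob(inf)` per sampled token instead of a real value. A rough standalone illustration of that shape is below; the `Logprob` dataclass here is a simplified stand-in for vLLM's (which also carries a rank and a decoded token).

from dataclasses import dataclass
from math import inf
from typing import Dict, List


@dataclass
class Logprob:
    # Simplified stand-in; only the field used here is modeled.
    logprob: float


def dummy_sampled_logprobs(next_token_ids: List[int]) -> List[Dict[int, Logprob]]:
    # One {token_id: Logprob(inf)} dict per sampled token, mirroring the fast
    # path: downstream code still sees the usual structure, but no logprob
    # values were ever gathered from the GPU.
    return [{token_id: Logprob(inf)} for token_id in next_token_ids]


print(dummy_sampled_logprobs([42, 7]))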
vllm/outputs.py

@@ -29,7 +29,7 @@ class CompletionOutput:
     index: int
     text: str
     token_ids: Tuple[int, ...]
-    cumulative_logprob: float
+    cumulative_logprob: Optional[float]
     logprobs: Optional[SampleLogprobs]
     finish_reason: Optional[str] = None
     stop_reason: Union[int, str, None] = None
@@ -124,13 +124,14 @@ class RequestOutput:
         include_logprobs = seq_group.sampling_params.logprobs is not None
         text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(seqs.index(seq),
-                             seq.get_output_text_to_return(text_buffer_length),
-                             seq.get_output_token_ids(),
-                             seq.get_cumulative_logprob(),
-                             seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status),
-                             seq.stop_reason) for seq in top_n_seqs
+            CompletionOutput(
+                seqs.index(seq),
+                seq.get_output_text_to_return(text_buffer_length),
+                seq.get_output_token_ids(),
+                seq.get_cumulative_logprob() if include_logprobs else None,
+                seq.output_logprobs if include_logprobs else None,
+                SequenceStatus.get_finished_reason(seq.status),
+                seq.stop_reason) for seq in top_n_seqs
         ]
 
         # Every sequence in the sequence group should have the same prompt.
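
Because `cumulative_logprob` is now `Optional[float]` and is only populated when logprobs were requested, callers that read it should guard against `None`. A hedged sketch of such a guard (the helper name is made up for illustration):

from typing import Optional


def format_cumulative_logprob(cumulative_logprob: Optional[float]) -> str:
    # None whenever the request was made with logprobs=None.
    if cumulative_logprob is None:
        return "n/a (logprobs not requested)"
    return f"{cumulative_logprob:.4f}"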
vllm/sampling_params.py

@@ -92,11 +92,12 @@ class SamplingParams:
         min_tokens: Minimum number of tokens to generate per output sequence
             before EOS or stop_token_ids can be generated
         logprobs: Number of log probabilities to return per output token.
-            Note that the implementation follows the OpenAI API: The return
-            result includes the log probabilities on the `logprobs` most likely
-            tokens, as well the chosen tokens. The API will always return the
-            log probability of the sampled token, so there may be up to
-            `logprobs+1` elements in the response.
+            When set to None, no probability is returned. If set to a non-None
+            value, the result includes the log probabilities of the specified
+            number of most likely tokens, as well as the chosen tokens.
+            Note that the implementation follows the OpenAI API: The API will
+            always return the log probability of the sampled token, so there
+            may be up to `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
@@ -168,8 +169,8 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.min_tokens = min_tokens
-        self.logprobs = logprobs
-        self.prompt_logprobs = prompt_logprobs
+        self.logprobs = 1 if logprobs is True else logprobs
+        self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
         # NOTE: This parameter is only exposed at the engine level for now.
         # It is not exposed in the OpenAI API server, as the OpenAI API does
         # not support returning only a list of token IDs.
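
The constructor change also normalizes boolean inputs: passing `True` is treated as requesting a single logprob, while `None` keeps logprob computation disabled. A quick sketch of the expected behavior, assuming the public `SamplingParams` constructor:

from vllm import SamplingParams

# True is normalized to 1; None leaves logprob computation disabled.
assert SamplingParams(logprobs=True).logprobs == 1
assert SamplingParams(prompt_logprobs=True).prompt_logprobs == 1
assert SamplingParams(logprobs=None).logprobs is None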