Delay GPU->CPU sync in sampling (#1337)

Antoni Baum 2023-10-30 09:01:34 -07:00 committed by GitHub
parent aa9af07cac
commit 15f5632365
3 changed files with 66 additions and 47 deletions
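A note on what this change does: before this commit, `_prune_hidden_states` and `_sample` rebuilt their index structures from Python lists on every step, after the forward pass had already produced `hidden_states`. This commit moves that bookkeeping into `Worker._prepare_inputs`, so the index tensors are created and copied to the GPU up front and the sampler can enqueue its work without a host-side detour. A minimal before/after sketch of the placement (stand-in code, not vLLM's actual API):

```python
import torch

# Before: gather indices were rebuilt from Python lists inside the
# sampler, forcing per-step tensor creation and an H2D copy after the
# forward pass had finished.
def old_prune(hidden_states: torch.Tensor, selected: list) -> torch.Tensor:
    idx = torch.tensor(selected, dtype=torch.long,
                       device=hidden_states.device)  # per-step copy here
    return hidden_states.view(-1, hidden_states.shape[-1]).index_select(0, idx)

# After: the worker ships a ready-made index tensor in InputMetadata,
# so pruning is a single gather on tensors that already live on-device.
def new_prune(hidden_states: torch.Tensor,
              selected_token_indices: torch.Tensor) -> torch.Tensor:
    return hidden_states.view(-1, hidden_states.shape[-1]).index_select(
        0, selected_token_indices)

h = torch.randn(2, 4, 16)
print(torch.equal(old_prune(h, [2, 7]),
                  new_prune(h, torch.tensor([2, 7]))))  # True
```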

vllm/model_executor/input_metadata.py

@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from xformers.ops import AttentionBias
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
@@ -29,6 +29,8 @@ class InputMetadata:
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        selected_token_indices: torch.Tensor,
+        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
         sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
@@ -38,6 +40,8 @@ class InputMetadata:
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
+        self.selected_token_indices = selected_token_indices
+        self.categorized_sample_indices = categorized_sample_indices

         self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
@@ -72,13 +76,16 @@ class InputMetadata:
     def __repr__(self) -> str:
         # Print only useful metadata.
-        return (f'InputMetadata('
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'num_prompts={self.num_prompts}, '
-                f'prompt_lens={self.prompt_lens}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
-                f'context_lens={self.context_lens}, '
-                f'max_context_len={self.max_context_len}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'block_tables={self.block_tables}, '
-                f'slot_mapping={self.slot_mapping})')
+        return (
+            f'InputMetadata('
+            f'num_prompt_tokens={self.num_prompt_tokens}, '
+            f'num_prompts={self.num_prompts}, '
+            f'prompt_lens={self.prompt_lens}, '
+            f'num_generation_tokens={self.num_generation_tokens}, '
+            f'context_lens={self.context_lens}, '
+            f'max_context_len={self.max_context_len}, '
+            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+            f'block_tables={self.block_tables}, '
+            f'selected_token_indices={self.selected_token_indices}, '
+            f'categorized_sample_indices={self.categorized_sample_indices}, '
+            f'slot_mapping={self.slot_mapping})')
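For reference, a sketch of what the two new `InputMetadata` fields carry. The `SamplingType` enum below is a stand-in for the real one imported from `vllm.sampling_params`, and the values are hypothetical:

```python
from enum import Enum
from typing import Dict

import torch

class SamplingType(Enum):  # stand-in for vllm.sampling_params.SamplingType
    GREEDY = 0
    RANDOM = 1
    BEAM = 2

# Rows of the flattened hidden states that feed into sampling
# (e.g. the last token of each prompt plus every generation token).
selected_token_indices = torch.tensor([3, 7, 8, 9], dtype=torch.long)

# For each sampling type, which of those kept rows it samples from.
categorized_sample_indices: Dict[SamplingType, torch.Tensor] = {
    SamplingType.GREEDY: torch.tensor([0, 1], dtype=torch.int),
    SamplingType.RANDOM: torch.tensor([2, 3], dtype=torch.int),
    SamplingType.BEAM: torch.tensor([], dtype=torch.int),
}
```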

vllm/model_executor/layers/sampler.py

@@ -109,29 +109,8 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    selected_token_indices: List[int] = []
-    start_idx = 0
-    for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            prompt_len = input_metadata.prompt_lens[i]
-            if sampling_params.prompt_logprobs is not None:
-                selected_token_indices.extend(
-                    range(start_idx, start_idx + prompt_len - 1))
-            selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += input_metadata.max_prompt_len
-        else:
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(start_idx, start_idx + num_seqs))
-            start_idx += num_seqs
-    selected_token_indices = torch.tensor(selected_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, selected_token_indices)
+    return hidden_states.index_select(0, input_metadata.selected_token_indices)


 def _get_penalties(
@@ -426,21 +405,11 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = {t: [] for t in SamplingType}
-    start_idx = 0
+    categorized_sample_indices = input_metadata.categorized_sample_indices
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
+        _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
-        if (i < input_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: prompt token positions do not need sample, skip
-            prompt_len = input_metadata.prompt_lens[i]
-            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
-        num_seqs = len(seq_ids)
-        categorized_sample_indices[sampling_type].extend(
-            range(start_idx, start_idx + num_seqs))
-        start_idx += num_seqs

     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
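`_sample` now takes `categorized_sample_indices` straight from the metadata instead of rebuilding it with a running `start_idx`. With the per-type indices precomputed, each sampling strategy reduces to tensor ops over its own rows; a rough, self-contained sketch of that pattern (hypothetical shapes, not the vLLM sampler itself):

```python
import torch

logprobs = torch.randn(6, 32000)          # pruned rows x vocab size
greedy_indices = torch.tensor([0, 1, 2])  # precomputed in the worker
random_indices = torch.tensor([3, 4, 5])

# Greedy rows: plain argmax over the selected rows.
greedy_tokens = torch.argmax(logprobs[greedy_indices], dim=-1)

# Random rows: sample from the softmax distribution.
probs = torch.softmax(logprobs[random_indices], dim=-1)
random_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
print(greedy_tokens.shape, random_tokens.shape)  # torch.Size([3]) twice
```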

vllm/worker/worker.py

@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
 from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
@@ -161,6 +161,10 @@ class Worker:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        selected_token_indices: List[int] = []
+        selected_token_start_idx = 0
+        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices_start_idx = 0

         # Add prompt tokens.
         prompt_lens: List[int] = []
@@ -180,6 +184,14 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
+            if sampling_params.prompt_logprobs is not None:
+                # NOTE: prompt token positions do not need sample, skip
+                categorized_sample_indices_start_idx += prompt_len - 1
+
+            categorized_sample_indices[sampling_params.sampling_type].append(
+                categorized_sample_indices_start_idx)
+            categorized_sample_indices_start_idx += 1
+
             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
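To make the running-offset bookkeeping above concrete, a worked example with hypothetical lengths: one prompt of length 4 with `prompt_logprobs` enabled, then one of length 3 without it. Prompt-logprob positions are scored but never sampled, so they advance the offset without being recorded:

```python
categorized_sample_indices_start_idx = 0

# Prompt 1: len 4, prompt_logprobs on. Its 3 leading positions are
# logprob-only, so skip over them first.
categorized_sample_indices_start_idx += 4 - 1              # now 3
sample_idx_prompt1 = categorized_sample_indices_start_idx  # 3
categorized_sample_indices_start_idx += 1                  # now 4

# Prompt 2: len 3, prompt_logprobs off. Just one sample index.
sample_idx_prompt2 = categorized_sample_indices_start_idx  # 4
categorized_sample_indices_start_idx += 1                  # now 5
```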
@@ -205,14 +217,37 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
         for seq_group_metadata in seq_group_metadata_list:
             if seq_group_metadata.is_prompt:
+                # We need to do this in this loop as we need to know max_seq_len
+                assert len(
+                    seq_ids) == 1, "Prompt input should have only one seq."
+                sampling_params = seq_group_metadata.sampling_params
+                if sampling_params.prompt_logprobs is not None:
+                    selected_token_indices.extend(
+                        range(selected_token_start_idx,
+                              selected_token_start_idx + prompt_len - 1))
+                selected_token_indices.append(selected_token_start_idx +
+                                              prompt_len - 1)
+                selected_token_start_idx += max_seq_len
                 continue

             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
+            num_seqs = len(seq_ids)
+            selected_token_indices.extend(
+                range(selected_token_start_idx,
+                      selected_token_start_idx + num_seqs))
+            selected_token_start_idx += num_seqs
+            categorized_sample_indices[sampling_params.sampling_type].extend(
+                range(categorized_sample_indices_start_idx,
+                      categorized_sample_indices_start_idx + num_seqs))
+            categorized_sample_indices_start_idx += num_seqs

             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
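The `selected_token_start_idx += max_seq_len` line is the subtle part: prompts are padded to `max_seq_len` before flattening, so each prompt owns a fixed-size stride of rows in the flattened hidden states, and the padding rows must be skipped. A worked example with two prompts of lengths 3 and 4 (hypothetical values, no `prompt_logprobs`):

```python
max_seq_len = 4            # longest prompt in the batch
prompt_lens = [3, 4]       # prompt 0 -> rows 0..2 (row 3 is padding),
                           # prompt 1 -> rows 4..7
selected_token_indices = []
start = 0
for plen in prompt_lens:
    # Only the last real token of each prompt feeds sampling here.
    selected_token_indices.append(start + plen - 1)
    start += max_seq_len   # jump over the padding to the next prompt
print(selected_token_indices)  # [2, 7]
```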
@@ -242,7 +277,6 @@ class Worker:
                     block_table = block_table[-sliding_window_blocks:]
             generation_block_tables.append(block_table)

-        max_seq_len = max(prompt_lens) if prompt_lens else 1
         padded_input_tokens = [
             _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
         ]
@@ -272,6 +306,13 @@ class Worker:
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
+        selected_token_indices = torch.tensor(selected_token_indices,
+                                              dtype=torch.long,
+                                              device="cuda")
+        categorized_sample_indices = {
+            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            for t, seq_ids in categorized_sample_indices.items()
+        }
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")
@@ -288,6 +329,8 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
+            selected_token_indices=selected_token_indices,
+            categorized_sample_indices=categorized_sample_indices,
             sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
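Putting the three files together, the step now has this shape: build every index list on the CPU in `_prepare_inputs`, upload once, run the forward pass, then prune and sample with device-resident tensors only. A compact, hypothetical mini-step mirroring that flow:

```python
import torch

def prepare(prompt_lens, max_seq_len):
    # Host-side bookkeeping, done before any GPU work is enqueued.
    selected, start = [], 0
    for plen in prompt_lens:
        selected.append(start + plen - 1)
        start += max_seq_len
    return torch.tensor(selected, dtype=torch.long)

prompt_lens, max_seq_len = [3, 4], 4
tokens = torch.randint(0, 100, (2, max_seq_len))
selected = prepare(prompt_lens, max_seq_len)            # up-front indices
hidden = torch.nn.functional.embedding(tokens, torch.randn(100, 16))
pruned = hidden.view(-1, 16).index_select(0, selected)  # as in the sampler
next_tokens = torch.argmax(pruned @ torch.randn(16, 100), dim=-1)
print(next_tokens.shape)  # torch.Size([2])
```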