Delay GPU->CPU sync in sampling (#1337)
This commit is contained in:
parent
aa9af07cac
commit
15f5632365
@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
import torch
|
||||
from xformers.ops import AttentionBias
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
@ -29,6 +29,8 @@ class InputMetadata:
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
block_tables: torch.Tensor,
|
||||
selected_token_indices: torch.Tensor,
|
||||
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
self.seq_groups = seq_groups
|
||||
@ -38,6 +40,8 @@ class InputMetadata:
|
||||
self.context_lens = context_lens
|
||||
self.max_context_len = max_context_len
|
||||
self.block_tables = block_tables
|
||||
self.selected_token_indices = selected_token_indices
|
||||
self.categorized_sample_indices = categorized_sample_indices
|
||||
|
||||
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
|
||||
self.to_cache = None
|
||||
@ -72,13 +76,16 @@ class InputMetadata:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Print only useful metadata.
|
||||
return (f'InputMetadata('
|
||||
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
||||
f'num_prompts={self.num_prompts}, '
|
||||
f'prompt_lens={self.prompt_lens}, '
|
||||
f'num_generation_tokens={self.num_generation_tokens}, '
|
||||
f'context_lens={self.context_lens}, '
|
||||
f'max_context_len={self.max_context_len}, '
|
||||
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
||||
f'block_tables={self.block_tables}, '
|
||||
f'slot_mapping={self.slot_mapping})')
|
||||
return (
|
||||
f'InputMetadata('
|
||||
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
||||
f'num_prompts={self.num_prompts}, '
|
||||
f'prompt_lens={self.prompt_lens}, '
|
||||
f'num_generation_tokens={self.num_generation_tokens}, '
|
||||
f'context_lens={self.context_lens}, '
|
||||
f'max_context_len={self.max_context_len}), '
|
||||
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
||||
f'block_tables={self.block_tables}, '
|
||||
f'selected_token_indices={self.selected_token_indices}, '
|
||||
f'categorized_sample_indices={self.categorized_sample_indices}, '
|
||||
f'slot_mapping={self.slot_mapping})')
|
||||
|
||||
@ -109,29 +109,8 @@ def _prune_hidden_states(
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
selected_token_indices: List[int] = []
|
||||
start_idx = 0
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
selected_token_indices.extend(
|
||||
range(start_idx, start_idx + prompt_len - 1))
|
||||
selected_token_indices.append(start_idx + prompt_len - 1)
|
||||
start_idx += input_metadata.max_prompt_len
|
||||
else:
|
||||
num_seqs = len(seq_ids)
|
||||
selected_token_indices.extend(
|
||||
range(start_idx, start_idx + num_seqs))
|
||||
start_idx += num_seqs
|
||||
|
||||
selected_token_indices = torch.tensor(selected_token_indices,
|
||||
dtype=torch.long,
|
||||
device=hidden_states.device)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
return hidden_states.index_select(0, selected_token_indices)
|
||||
return hidden_states.index_select(0, input_metadata.selected_token_indices)
|
||||
|
||||
|
||||
def _get_penalties(
|
||||
@ -426,21 +405,11 @@ def _sample(
|
||||
input_metadata: InputMetadata,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices = {t: [] for t in SamplingType}
|
||||
start_idx = 0
|
||||
categorized_sample_indices = input_metadata.categorized_sample_indices
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
_, sampling_params = seq_group
|
||||
sampling_type = sampling_params.sampling_type
|
||||
if (i < input_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
# NOTE: prompt token positions do not need sample, skip
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
start_idx += prompt_len - 1
|
||||
categorized_seq_group_ids[sampling_type].append(i)
|
||||
num_seqs = len(seq_ids)
|
||||
categorized_sample_indices[sampling_type].extend(
|
||||
range(start_idx, start_idx + num_seqs))
|
||||
start_idx += num_seqs
|
||||
|
||||
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
|
||||
for sampling_type in SamplingType:
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
from vllm.model_executor import get_model, InputMetadata, set_random_seed
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
initialize_model_parallel)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
|
||||
@ -161,6 +161,10 @@ class Worker:
|
||||
input_tokens: List[List[int]] = []
|
||||
input_positions: List[List[int]] = []
|
||||
slot_mapping: List[List[int]] = []
|
||||
selected_token_indices: List[int] = []
|
||||
selected_token_start_idx = 0
|
||||
categorized_sample_indices = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices_start_idx = 0
|
||||
|
||||
# Add prompt tokens.
|
||||
prompt_lens: List[int] = []
|
||||
@ -180,6 +184,14 @@ class Worker:
|
||||
prompt_len = len(prompt_tokens)
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
# NOTE: prompt token positions do not need sample, skip
|
||||
categorized_sample_indices_start_idx += prompt_len - 1
|
||||
|
||||
categorized_sample_indices[sampling_params.sampling_type].append(
|
||||
categorized_sample_indices_start_idx)
|
||||
categorized_sample_indices_start_idx += 1
|
||||
|
||||
input_tokens.append(prompt_tokens)
|
||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
||||
# is always the first token in the sequence.
|
||||
@ -205,14 +217,37 @@ class Worker:
|
||||
max_num_blocks_per_seq = 0
|
||||
context_lens: List[int] = []
|
||||
generation_block_tables: List[List[int]] = []
|
||||
max_seq_len = max(prompt_lens) if prompt_lens else 1
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
if seq_group_metadata.is_prompt:
|
||||
# We need to do this in this loop as we need to know max_seq_len
|
||||
assert len(
|
||||
seq_ids) == 1, "Prompt input should have only one seq."
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
selected_token_indices.extend(
|
||||
range(selected_token_start_idx,
|
||||
selected_token_start_idx + prompt_len - 1))
|
||||
selected_token_indices.append(selected_token_start_idx +
|
||||
prompt_len - 1)
|
||||
selected_token_start_idx += max_seq_len
|
||||
continue
|
||||
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
seq_groups.append((seq_ids, sampling_params))
|
||||
|
||||
num_seqs = len(seq_ids)
|
||||
selected_token_indices.extend(
|
||||
range(selected_token_start_idx,
|
||||
selected_token_start_idx + num_seqs))
|
||||
selected_token_start_idx += num_seqs
|
||||
|
||||
categorized_sample_indices[sampling_params.sampling_type].extend(
|
||||
range(categorized_sample_indices_start_idx,
|
||||
categorized_sample_indices_start_idx + num_seqs))
|
||||
categorized_sample_indices_start_idx += num_seqs
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
@ -242,7 +277,6 @@ class Worker:
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
generation_block_tables.append(block_table)
|
||||
|
||||
max_seq_len = max(prompt_lens) if prompt_lens else 1
|
||||
padded_input_tokens = [
|
||||
_pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
|
||||
]
|
||||
@ -272,6 +306,13 @@ class Worker:
|
||||
context_lens_tensor = torch.tensor(context_lens,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
selected_token_indices = torch.tensor(selected_token_indices,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
categorized_sample_indices = {
|
||||
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
|
||||
for t, seq_ids in categorized_sample_indices.items()
|
||||
}
|
||||
block_tables_tensor = torch.tensor(padded_block_tables,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
@ -288,6 +329,8 @@ class Worker:
|
||||
context_lens=context_lens_tensor,
|
||||
max_context_len=max_context_len,
|
||||
block_tables=block_tables_tensor,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=categorized_sample_indices,
|
||||
sliding_window=self.sliding_window,
|
||||
)
|
||||
return tokens_tensor, positions_tensor, input_metadata
|
||||
|
||||
Loading…
Reference in New Issue
Block a user