Delay GPU->CPU sync in sampling (#1337)
This commit is contained in:
parent aa9af07cac
commit 15f5632365
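Taken together, the hunks below move the sampling index bookkeeping out of the sampler and into the worker's input preparation: `selected_token_indices` and `categorized_sample_indices` are built as Python lists while the batch is assembled, converted to CUDA tensors once, and carried in `InputMetadata`, so the sampler only consumes ready-made GPU tensors. The commit title suggests the payoff is that no host-side tensor construction is interleaved with the sampling kernels, so the eventual GPU->CPU copy of sampled tokens can be deferred. A minimal sketch of the before/after pattern (the function names are stand-ins; only the `index_select` usage mirrors the diff):

```python
import torch

def prune_before(hidden_states: torch.Tensor, indices: list) -> torch.Tensor:
    # Old pattern: the index tensor is materialized from a Python list in the
    # middle of the forward pass, right before it is consumed.
    idx = torch.tensor(indices, dtype=torch.long, device=hidden_states.device)
    return hidden_states.index_select(0, idx)

def prune_after(hidden_states: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # New pattern: the index tensor was created once during input preparation
    # (see the worker hunks below) and is simply reused here.
    return hidden_states.index_select(0, idx)
```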
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from xformers.ops import AttentionBias
 
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
 
 
@@ -29,6 +29,8 @@ class InputMetadata:
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        selected_token_indices: torch.Tensor,
+        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
         sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
@@ -38,6 +40,8 @@ class InputMetadata:
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
+        self.selected_token_indices = selected_token_indices
+        self.categorized_sample_indices = categorized_sample_indices
 
         self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
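For reference, a shape/type sketch of the two fields `InputMetadata` now carries (values are made up, and the enum is a stand-in for `vllm.sampling_params.SamplingType`):

```python
import torch
from enum import Enum

class SamplingType(Enum):  # stand-in for vllm.sampling_params.SamplingType
    GREEDY = 0
    RANDOM = 1

# Positions in the flattened hidden states whose logits are actually needed.
selected_token_indices = torch.tensor([1, 6, 8, 9], dtype=torch.long)
# For each sampling strategy, the row positions it samples from, counted
# within the selected rows above.
categorized_sample_indices = {
    SamplingType.GREEDY: torch.tensor([0, 1], dtype=torch.int),
    SamplingType.RANDOM: torch.tensor([2, 3], dtype=torch.int),
}
```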
@@ -72,13 +76,16 @@ class InputMetadata:
 
     def __repr__(self) -> str:
         # Print only useful metadata.
-        return (f'InputMetadata('
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'num_prompts={self.num_prompts}, '
-                f'prompt_lens={self.prompt_lens}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
-                f'context_lens={self.context_lens}, '
-                f'max_context_len={self.max_context_len}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'block_tables={self.block_tables}, '
-                f'slot_mapping={self.slot_mapping})')
+        return (
+            f'InputMetadata('
+            f'num_prompt_tokens={self.num_prompt_tokens}, '
+            f'num_prompts={self.num_prompts}, '
+            f'prompt_lens={self.prompt_lens}, '
+            f'num_generation_tokens={self.num_generation_tokens}, '
+            f'context_lens={self.context_lens}, '
+            f'max_context_len={self.max_context_len}), '
+            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+            f'block_tables={self.block_tables}, '
+            f'selected_token_indices={self.selected_token_indices}, '
+            f'categorized_sample_indices={self.categorized_sample_indices}, '
+            f'slot_mapping={self.slot_mapping})')
@@ -109,29 +109,8 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    selected_token_indices: List[int] = []
-    start_idx = 0
-    for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            prompt_len = input_metadata.prompt_lens[i]
-            if sampling_params.prompt_logprobs is not None:
-                selected_token_indices.extend(
-                    range(start_idx, start_idx + prompt_len - 1))
-            selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += input_metadata.max_prompt_len
-        else:
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(start_idx, start_idx + num_seqs))
-            start_idx += num_seqs
-
-    selected_token_indices = torch.tensor(selected_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, selected_token_indices)
+    return hidden_states.index_select(0, input_metadata.selected_token_indices)
 
 
 def _get_penalties(
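A toy example (assumed sizes) of the index layout that `_prune_hidden_states` still relies on: prompts are padded to the maximum prompt length, decode tokens follow, and `selected_token_indices` picks out the last token of each prompt plus every decode token:

```python
import torch

max_prompt_len = 4
prompt_lens = [2, 3]       # two prompt sequences
num_generation_tokens = 2  # two decoding sequences

hidden = torch.arange(max_prompt_len * len(prompt_lens) + num_generation_tokens)
selected_token_indices = torch.tensor([
    prompt_lens[0] - 1,                   # last token of prompt 0
    max_prompt_len + prompt_lens[1] - 1,  # last token of prompt 1 (after padding)
    2 * max_prompt_len,                   # first decode token
    2 * max_prompt_len + 1,               # second decode token
])
print(hidden.index_select(0, selected_token_indices))  # tensor([1, 6, 8, 9])
```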
@@ -426,21 +405,11 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = {t: [] for t in SamplingType}
-    start_idx = 0
+    categorized_sample_indices = input_metadata.categorized_sample_indices
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
+        _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
-        if (i < input_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: prompt token positions do not need sample, skip
-            prompt_len = input_metadata.prompt_lens[i]
-            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
-        num_seqs = len(seq_ids)
-        categorized_sample_indices[sampling_type].extend(
-            range(start_idx, start_idx + num_seqs))
-        start_idx += num_seqs
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
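With `categorized_sample_indices` arriving precomputed, `_sample` no longer rebuilds it every step; the dict lets sequences that share a sampling strategy be processed as one batch. A rough, self-contained sketch of that grouping (string keys stand in for `SamplingType` members, shapes are made up):

```python
import torch

probs = torch.rand(4, 32000)  # one row per sampled position, after pruning
categorized_sample_indices = {
    "greedy": torch.tensor([0, 2]),
    "random": torch.tensor([1, 3]),
}
# Greedy rows take the argmax; random rows are sampled from their distribution.
greedy_tokens = probs[categorized_sample_indices["greedy"]].argmax(dim=-1)
random_tokens = torch.multinomial(probs[categorized_sample_indices["random"]], 1)
```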
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
 from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
@@ -161,6 +161,10 @@ class Worker:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        selected_token_indices: List[int] = []
+        selected_token_start_idx = 0
+        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices_start_idx = 0
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
@@ -180,6 +184,14 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
+            if sampling_params.prompt_logprobs is not None:
+                # NOTE: prompt token positions do not need sample, skip
+                categorized_sample_indices_start_idx += prompt_len - 1
+
+            categorized_sample_indices[sampling_params.sampling_type].append(
+                categorized_sample_indices_start_idx)
+            categorized_sample_indices_start_idx += 1
+
             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
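A worked example (assumed values) of the prompt-side bookkeeping added here: when `prompt_logprobs` is requested, the prompt's first `prompt_len - 1` positions are reserved for logprobs and skipped for sampling, and only the final prompt position is registered as a sample index:

```python
categorized_sample_indices_start_idx = 0
prompt_len = 5
wants_prompt_logprobs = True  # stands in for sampling_params.prompt_logprobs is not None

sample_indices = []
if wants_prompt_logprobs:
    categorized_sample_indices_start_idx += prompt_len - 1  # skip 4 logprob-only rows
sample_indices.append(categorized_sample_indices_start_idx)  # sample from row 4
categorized_sample_indices_start_idx += 1                    # next sequence starts at 5
print(sample_indices)  # [4]
```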
@@ -205,14 +217,37 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
         for seq_group_metadata in seq_group_metadata_list:
             if seq_group_metadata.is_prompt:
+                # We need to do this in this loop as we need to know max_seq_len
+                assert len(
+                    seq_ids) == 1, "Prompt input should have only one seq."
+                sampling_params = seq_group_metadata.sampling_params
+                if sampling_params.prompt_logprobs is not None:
+                    selected_token_indices.extend(
+                        range(selected_token_start_idx,
+                              selected_token_start_idx + prompt_len - 1))
+                selected_token_indices.append(selected_token_start_idx +
+                                              prompt_len - 1)
+                selected_token_start_idx += max_seq_len
                 continue
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
 
+            num_seqs = len(seq_ids)
+            selected_token_indices.extend(
+                range(selected_token_start_idx,
+                      selected_token_start_idx + num_seqs))
+            selected_token_start_idx += num_seqs
+
+            categorized_sample_indices[sampling_params.sampling_type].extend(
+                range(categorized_sample_indices_start_idx,
+                      categorized_sample_indices_start_idx + num_seqs))
+            categorized_sample_indices_start_idx += num_seqs
+
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
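And a worked example (assumed sizes) of the `selected_token_indices` bookkeeping in this loop: a prompt advances the running offset by `max_seq_len` because prompts sit in padded slots, while a decode group advances it by its number of sequences:

```python
max_seq_len = 4
selected_token_start_idx = 0
selected_token_indices = []

# One prompt of length 3 with prompt_logprobs requested:
prompt_len = 3
selected_token_indices.extend(
    range(selected_token_start_idx, selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx + prompt_len - 1)
selected_token_start_idx += max_seq_len

# One decode group with two sequences:
num_seqs = 2
selected_token_indices.extend(
    range(selected_token_start_idx, selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs

print(selected_token_indices)  # [0, 1, 2, 4, 5]
```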
@@ -242,7 +277,6 @@ class Worker:
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)
 
-        max_seq_len = max(prompt_lens) if prompt_lens else 1
         padded_input_tokens = [
             _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
         ]
@@ -272,6 +306,13 @@ class Worker:
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
+        selected_token_indices = torch.tensor(selected_token_indices,
+                                              dtype=torch.long,
+                                              device="cuda")
+        categorized_sample_indices = {
+            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            for t, seq_ids in categorized_sample_indices.items()
+        }
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")
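The conversion above is the single point where the accumulated Python lists become device tensors. A standalone sketch of the same dict comprehension, using CPU tensors so it runs without a GPU (the diff uses device="cuda") and a stand-in enum:

```python
import torch
from enum import Enum

class SamplingType(Enum):  # stand-in for vllm.sampling_params.SamplingType
    GREEDY = 0
    RANDOM = 1

categorized_sample_indices = {SamplingType.GREEDY: [0, 2], SamplingType.RANDOM: [1]}
categorized_sample_indices = {
    t: torch.tensor(seq_ids, dtype=torch.int)
    for t, seq_ids in categorized_sample_indices.items()
}
print(categorized_sample_indices[SamplingType.GREEDY])  # tensor([0, 2], dtype=torch.int32)
```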
@@ -288,6 +329,8 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
+            selected_token_indices=selected_token_indices,
+            categorized_sample_indices=categorized_sample_indices,
             sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata