From 15f5632365a98fd43ea42e4948a995aa399e99b5 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Mon, 30 Oct 2023 09:01:34 -0700
Subject: [PATCH] Delay GPU->CPU sync in sampling (#1337)

---
 vllm/model_executor/input_metadata.py | 29 ++++++++++-------
 vllm/model_executor/layers/sampler.py | 37 ++-------------------
 vllm/worker/worker.py                 | 47 +++++++++++++++++++++++++--
 3 files changed, 66 insertions(+), 47 deletions(-)

diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py
index 265f3d35..b3b5852e 100644
--- a/vllm/model_executor/input_metadata.py
+++ b/vllm/model_executor/input_metadata.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from xformers.ops import AttentionBias
 
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
 
 
@@ -29,6 +29,8 @@ class InputMetadata:
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        selected_token_indices: torch.Tensor,
+        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
         sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
@@ -38,6 +40,8 @@ class InputMetadata:
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
+        self.selected_token_indices = selected_token_indices
+        self.categorized_sample_indices = categorized_sample_indices
 
         self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
@@ -72,13 +76,16 @@ class InputMetadata:
 
     def __repr__(self) -> str:
         # Print only useful metadata.
-        return (f'InputMetadata('
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'num_prompts={self.num_prompts}, '
-                f'prompt_lens={self.prompt_lens}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
-                f'context_lens={self.context_lens}, '
-                f'max_context_len={self.max_context_len}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'block_tables={self.block_tables}, '
-                f'slot_mapping={self.slot_mapping})')
+        return (
+            f'InputMetadata('
+            f'num_prompt_tokens={self.num_prompt_tokens}, '
+            f'num_prompts={self.num_prompts}, '
+            f'prompt_lens={self.prompt_lens}, '
+            f'num_generation_tokens={self.num_generation_tokens}, '
+            f'context_lens={self.context_lens}, '
+            f'max_context_len={self.max_context_len}, '
+            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+            f'block_tables={self.block_tables}, '
+            f'selected_token_indices={self.selected_token_indices}, '
+            f'categorized_sample_indices={self.categorized_sample_indices}, '
+            f'slot_mapping={self.slot_mapping})')
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 7a4dc2fd..6a29f1af 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -109,29 +109,8 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    selected_token_indices: List[int] = []
-    start_idx = 0
-    for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
- prompt_len = input_metadata.prompt_lens[i] - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(start_idx, start_idx + prompt_len - 1)) - selected_token_indices.append(start_idx + prompt_len - 1) - start_idx += input_metadata.max_prompt_len - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(start_idx, start_idx + num_seqs)) - start_idx += num_seqs - - selected_token_indices = torch.tensor(selected_token_indices, - dtype=torch.long, - device=hidden_states.device) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - return hidden_states.index_select(0, selected_token_indices) + return hidden_states.index_select(0, input_metadata.selected_token_indices) def _get_penalties( @@ -426,21 +405,11 @@ def _sample( input_metadata: InputMetadata, ) -> List[Tuple[List[int], List[int]]]: categorized_seq_group_ids = {t: [] for t in SamplingType} - categorized_sample_indices = {t: [] for t in SamplingType} - start_idx = 0 + categorized_sample_indices = input_metadata.categorized_sample_indices for i, seq_group in enumerate(input_metadata.seq_groups): - seq_ids, sampling_params = seq_group + _, sampling_params = seq_group sampling_type = sampling_params.sampling_type - if (i < input_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - # NOTE: prompt token positions do not need sample, skip - prompt_len = input_metadata.prompt_lens[i] - start_idx += prompt_len - 1 categorized_seq_group_ids[sampling_type].append(i) - num_seqs = len(seq_ids) - categorized_sample_indices[sampling_type].extend( - range(start_idx, start_idx + num_seqs)) - start_idx += num_seqs sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} for sampling_type in SamplingType: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 1b1f116a..fd6faecc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.model_executor import get_model, InputMetadata, set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes @@ -161,6 +161,10 @@ class Worker: input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + selected_token_indices: List[int] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 # Add prompt tokens. prompt_lens: List[int] = [] @@ -180,6 +184,14 @@ class Worker: prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += prompt_len - 1 + + categorized_sample_indices[sampling_params.sampling_type].append( + categorized_sample_indices_start_idx) + categorized_sample_indices_start_idx += 1 + input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. 
@@ -205,14 +217,37 @@ class Worker: max_num_blocks_per_seq = 0 context_lens: List[int] = [] generation_block_tables: List[List[int]] = [] + max_seq_len = max(prompt_lens) if prompt_lens else 1 for seq_group_metadata in seq_group_metadata_list: if seq_group_metadata.is_prompt: + # We need to do this in this loop as we need to know max_seq_len + assert len( + seq_ids) == 1, "Prompt input should have only one seq." + sampling_params = seq_group_metadata.sampling_params + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + prompt_len - 1)) + selected_token_indices.append(selected_token_start_idx + + prompt_len - 1) + selected_token_start_idx += max_seq_len continue seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[sampling_params.sampling_type].extend( + range(categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + num_seqs)) + categorized_sample_indices_start_idx += num_seqs + for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() @@ -242,7 +277,6 @@ class Worker: block_table = block_table[-sliding_window_blocks:] generation_block_tables.append(block_table) - max_seq_len = max(prompt_lens) if prompt_lens else 1 padded_input_tokens = [ _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens ] @@ -272,6 +306,13 @@ class Worker: context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device="cuda") + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long, + device="cuda") + categorized_sample_indices = { + t: torch.tensor(seq_ids, dtype=torch.int, device="cuda") + for t, seq_ids in categorized_sample_indices.items() + } block_tables_tensor = torch.tensor(padded_block_tables, dtype=torch.int, device="cuda") @@ -288,6 +329,8 @@ class Worker: context_lens=context_lens_tensor, max_context_len=max_context_len, block_tables=block_tables_tensor, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, sliding_window=self.sliding_window, ) return tokens_tensor, positions_tensor, input_metadata
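
Note for readers following the indexing change outside the vLLM tree: the sketch below is a minimal, self-contained illustration of the idea behind this patch, not vLLM's API. The index list that the sampler previously rebuilt in Python on every step is now computed once on the host while the inputs are prepared, copied to the device a single time, and consumed directly by `index_select`, so sampling no longer waits on a GPU->CPU round trip. The function names (`prepare_sample_indices`, `prune_hidden_states_example`) are hypothetical, and the prompt-logprobs handling from the real patch is omitted for brevity.

# Illustrative sketch only (simplified relative to the patch): build the
# sampling index list on the host during input preparation, transfer it to the
# device once, and index the hidden states with the precomputed tensor.
from typing import List

import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def prepare_sample_indices(prompt_lens: List[int],
                           num_generation_seqs: int) -> torch.Tensor:
    """Host-side computation of which rows of the padded hidden states to keep."""
    max_prompt_len = max(prompt_lens) if prompt_lens else 1
    selected: List[int] = []
    start_idx = 0
    # Prompts are padded to max_prompt_len; only the last prompt token is
    # sampled (the prompt-logprobs case in the real patch also keeps the
    # earlier prompt positions).
    for prompt_len in prompt_lens:
        selected.append(start_idx + prompt_len - 1)
        start_idx += max_prompt_len
    # Each generation sequence contributes exactly one token.
    selected.extend(range(start_idx, start_idx + num_generation_seqs))
    # Single host->device copy; nothing has to come back to the CPU later.
    return torch.tensor(selected, dtype=torch.long, device=DEVICE)


def prune_hidden_states_example(
        hidden_states: torch.Tensor,
        selected_token_indices: torch.Tensor) -> torch.Tensor:
    """Keep only the rows that will actually be sampled, entirely on device."""
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0, selected_token_indices)


if __name__ == "__main__":
    # Two padded prompts (lengths 3 and 5) followed by four generation tokens.
    prompt_lens = [3, 5]
    num_generation_seqs = 4
    hidden = torch.randn(2 * 5 + num_generation_seqs, 8, device=DEVICE)
    indices = prepare_sample_indices(prompt_lens, num_generation_seqs)
    pruned = prune_hidden_states_example(hidden, indices)
    assert pruned.shape == (len(prompt_lens) + num_generation_seqs, 8)

The same pattern covers `categorized_sample_indices`: the per-`SamplingType` index lists are built alongside `selected_token_indices` during input preparation and converted to device tensors in the same place, so `_sample` can group sequences by sampling type without reading anything back from the GPU.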