From 55f8b0a5def22ed6b85d3b91b726a7573d54313b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 10 May 2023 23:39:12 -0700 Subject: [PATCH] Implement presence and frequency penalties (#95) --- cacheflow/core/scheduler.py | 23 ++--- cacheflow/frontend/fastapi_frontend.py | 2 +- cacheflow/frontend/simple_frontend.py | 5 +- cacheflow/model_executor/input_metadata.py | 10 +- cacheflow/model_executor/layers/sampler.py | 114 +++++++++++++++++++-- cacheflow/sampling_params.py | 35 +++++-- cacheflow/sequence.py | 69 ++++++++----- cacheflow/worker/worker.py | 35 ++++--- simple_server.py | 4 +- 9 files changed, 215 insertions(+), 82 deletions(-) diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py index 02e864e6..12f6157e 100644 --- a/cacheflow/core/scheduler.py +++ b/cacheflow/core/scheduler.py @@ -3,11 +3,12 @@ import time from typing import Dict, List, Optional, Tuple from cacheflow.core.block_manager import BlockSpaceManager -from cacheflow.logger import init_logger from cacheflow.core.policy import PolicyFactory +from cacheflow.logger import init_logger from cacheflow.sampling_params import SamplingParams -from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata, - SequenceOutputs, SequenceStatus) +from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupMetadata, SequenceOutputs, + SequenceStatus) logger = init_logger(__name__) @@ -246,27 +247,17 @@ class Scheduler: group_id = seq_group.group_id is_prompt = group_id in prompt_group_ids - input_tokens: Dict[int, List[int]] = {} - seq_logprobs: Dict[int, float] = {} + seq_data: Dict[int, List[SequenceData]] = {} block_tables: Dict[int, List[int]] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id + seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) - if is_prompt: - input_tokens[seq_id] = seq.get_token_ids() - else: - input_tokens[seq_id] = [seq.get_last_token_id()] - seq_logprobs[seq_id] = seq.cumulative_logprobs - # NOTE(woosuk): Sequences in the same group have the same - # sequence length - seq_len = seq.get_len() seq_group_metadata = SequenceGroupMetadata( group_id=group_id, is_prompt=is_prompt, - input_tokens=input_tokens, - context_len=seq_len, - seq_logprobs=seq_logprobs, + seq_data=seq_data, sampling_params=self.sampling_params[group_id], block_tables=block_tables, ) diff --git a/cacheflow/frontend/fastapi_frontend.py b/cacheflow/frontend/fastapi_frontend.py index 59e66a4c..c4712097 100644 --- a/cacheflow/frontend/fastapi_frontend.py +++ b/cacheflow/frontend/fastapi_frontend.py @@ -96,7 +96,7 @@ class FastAPIServer: seqs: List[Sequence] = [] for _ in range(sampling_params.n): seq_id = next(self.seq_counter) - seq = Sequence(seq_id, token_ids, block_size=self.block_size) + seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size) seqs.append(seq) arrival_time = time.time() diff --git a/cacheflow/frontend/simple_frontend.py b/cacheflow/frontend/simple_frontend.py index 9d65e4f0..19a8937a 100644 --- a/cacheflow/frontend/simple_frontend.py +++ b/cacheflow/frontend/simple_frontend.py @@ -35,10 +35,11 @@ class SimpleFrontend: sampling_params: SamplingParams, ) -> None: token_ids = self.tokenizer.encode(prompt) - self._add_query(token_ids, sampling_params) + self._add_query(prompt, token_ids, sampling_params) def _add_query( self, + prompt: str, token_ids: List[int], sampling_params: SamplingParams, arrival_time: Optional[float] = None, @@ -48,7 +49,7 @@ class SimpleFrontend: seqs: 
List[Sequence] = [] for _ in range(sampling_params.n): seq_id = next(self.seq_counter) - seq = Sequence(seq_id, token_ids, block_size=self.block_size) + seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size) seqs.append(seq) group_id = next(self.seq_group_counter) diff --git a/cacheflow/model_executor/input_metadata.py b/cacheflow/model_executor/input_metadata.py index 943524c9..465047aa 100644 --- a/cacheflow/model_executor/input_metadata.py +++ b/cacheflow/model_executor/input_metadata.py @@ -1,17 +1,18 @@ -from typing import List, Dict, Tuple +from typing import Dict, List, Tuple import torch from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from cacheflow.sampling_params import SamplingParams +from cacheflow.sequence import SequenceData class InputMetadata: def __init__( self, - seq_groups: List[Tuple[List[int], SamplingParams]], - seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs. + seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params). + seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData. prompt_lens: List[int], slot_mapping: torch.Tensor, context_lens: torch.Tensor, @@ -19,7 +20,7 @@ class InputMetadata: block_tables: torch.Tensor, ) -> None: self.seq_groups = seq_groups - self.seq_logprobs = seq_logprobs + self.seq_data = seq_data self.prompt_lens = prompt_lens self.slot_mapping = slot_mapping self.context_lens = context_lens @@ -39,6 +40,7 @@ class InputMetadata: assert context_lens.shape[0] == self.num_generation_tokens def __repr__(self) -> str: + # Print only useful metadata. return (f'InputMetadata(' f'num_valid_tokens={self.num_valid_tokens}, ' f'num_prompt_tokens={self.num_prompt_tokens}, ' diff --git a/cacheflow/model_executor/layers/sampler.py b/cacheflow/model_executor/layers/sampler.py index 7a7e53fb..838339c7 100644 --- a/cacheflow/model_executor/layers/sampler.py +++ b/cacheflow/model_executor/layers/sampler.py @@ -1,5 +1,6 @@ from typing import Dict, List, Tuple +import numpy as np import torch import torch.nn as nn @@ -31,6 +32,16 @@ class Sampler(nn.Module): # Remove paddings in vocab (if any). logits = logits[:, :self.vocab_size] + # Apply presence and frequency penalties. + output_tokens = _get_output_tokens(input_metadata) + assert len(output_tokens) == logits.shape[0] + presence_penalties, frequency_penalties = _get_penalties(input_metadata) + assert len(presence_penalties) == logits.shape[0] + assert len(frequency_penalties) == logits.shape[0] + logits = _apply_penalties( + logits, output_tokens, presence_penalties, frequency_penalties, + self.vocab_size) + # Apply temperature scaling. temperatures = _get_temperatures(input_metadata) assert len(temperatures) == logits.shape[0] @@ -43,16 +54,14 @@ class Sampler(nn.Module): # We use float32 for probabilities and log probabilities. # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities (before applying top-p). + # Compute the log probabilities (before applying top-p and top-k). logprobs = torch.log(probs) # Apply top-p and top-k truncation. 
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size) assert len(top_ps) == len(top_ks) == probs.shape[0] if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks): - p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device) - k = torch.tensor(top_ks, dtype=torch.int, device=probs.device) - probs = _apply_top_p_top_k(probs, p, k) + probs = _apply_top_p_top_k(probs, top_ps, top_ks) # Sample the next tokens. return _sample(probs, logprobs, input_metadata) @@ -72,6 +81,93 @@ def _prune_hidden_states( return hidden_states[last_token_indicies] +def _get_penalties( + input_metadata: InputMetadata, +) -> Tuple[List[float], List[float]]: + # Collect the presence and frequency penalties. + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + if i < input_metadata.num_prompts: + # A prompt input. + presence_penalties.append(p) + frequency_penalties.append(f) + else: + # A generation token. + presence_penalties += [p] * len(seq_ids) + frequency_penalties += [f] * len(seq_ids) + return presence_penalties, frequency_penalties + + +def _get_output_tokens( + input_metadata: InputMetadata, +) -> List[List[int]]: + output_tokens: List[List[int]] = [] + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, _ = seq_group + if i < input_metadata.num_prompts: + # A prompt input. + # NOTE: While the prompt input usually has no output tokens, + # it may have output tokens in the case of recomputation. + seq_id = seq_ids[0] + seq_data = input_metadata.seq_data[seq_id] + output_tokens.append(seq_data.output_token_ids) + else: + # A generation token. + for seq_id in seq_ids: + seq_data = input_metadata.seq_data[seq_id] + output_tokens.append(seq_data.output_token_ids) + return output_tokens + + +def _apply_penalties( + logits: torch.Tensor, + output_tokens: List[List[int]], + presence_penalties: List[float], + frequency_penalties: List[float], + vocab_size: int, +) -> torch.Tensor: + num_seqs = logits.shape[0] + # Collect the indices of sequences that have non-zero penalties. + indices = [] + for i in range(num_seqs): + if not output_tokens[i]: + continue + p = presence_penalties[i] + f = frequency_penalties[i] + if p == 0.0 and f == 0.0: + continue + indices.append(i) + + # Return early if all sequences have zero penalties. + if not indices: + return logits + + bin_counts = [] + for i in indices: + bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size)) + bin_counts = np.stack(bin_counts, axis=0) + bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype, + device=logits.device) + + frequency_penalties = [frequency_penalties[i] for i in indices] + frequency_penalties = torch.tensor( + frequency_penalties, dtype=logits.dtype, device=logits.device) + presence_penalties = [presence_penalties[i] for i in indices] + presence_penalties = torch.tensor( + presence_penalties, dtype=logits.dtype, device=logits.device) + + # We follow the definition in OpenAI API. 
+ # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts + presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype) + logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask + return logits + + def _get_temperatures( input_metadata: InputMetadata, ) -> List[float]: @@ -121,10 +217,11 @@ def _get_top_p_top_k( def _apply_top_p_top_k( probs: torch.Tensor, - p: torch.Tensor, - k: torch.Tensor, + top_ps: List[float], + top_ks: List[int], ) -> torch.Tensor: - # TODO(woosuk): Optimize. + p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device) + k = torch.tensor(top_ks, dtype=torch.int, device=probs.device) probs_sort, probs_idx = probs.sort(dim=-1, descending=True) # Apply top-p. @@ -286,7 +383,8 @@ def _sample( # Sample the next tokens. seq_logprobs = [ - input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids] + input_metadata.seq_data[seq_id].cumulative_logprobs + for seq_id in seq_ids] parent_seq_ids, next_token_ids = _sample_from_generation_tokens( seq_ids, prob, logprob, seq_logprobs, sampling_params) diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index 589874eb..e8f67056 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -6,6 +6,8 @@ class SamplingParams: def __init__( self, n: int, + presence_penalty: float, + frequency_penalty: float, temperature: float, top_p: float, top_k: int, @@ -16,6 +18,12 @@ class SamplingParams: ) -> None: if n < 1: raise ValueError(f"n must be at least 1, got {n}.") + if not -2.0 <= presence_penalty <= 2.0: + raise ValueError( + f"presence_penalty must be in [-2, 2], got {presence_penalty}.") + if not -2.0 <= frequency_penalty <= 2.0: + raise ValueError( + f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.") if temperature < 0.0: raise ValueError( f"temperature must be non-negative, got {temperature}.") @@ -57,6 +65,8 @@ class SamplingParams: "top_k must be -1 when using greedy sampling.") self.n = n + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty self.temperature = temperature self.top_p = top_p self.top_k = top_k @@ -67,6 +77,8 @@ class SamplingParams: def __repr__(self) -> str: return (f"SamplingParams(n={self.n}, " + f"presence_penalty={self.presence_penalty}, " + f"frequency_penalty={self.frequency_penalty}, " f"temperature={self.temperature}, " f"top_p={self.top_p}, " f"top_k={self.top_k}," @@ -77,13 +89,18 @@ class SamplingParams: @classmethod def from_dict(cls, d: Dict) -> "SamplingParams": - return cls( - n=d.get("n", 1), - temperature=d.get("temperature", 1.0), - top_p=d.get("top_p", 1.0), - top_k=d.get("top_k", -1), - use_beam_search=d.get("use_beam_search", False), - stop_token_ids=set(d.get("stop_token_ids", set())), - max_num_steps=d.get("max_num_steps", 16), - num_logprobs=d.get("num_logprobs", 0), + sampling_params = cls( + n=d.pop("n", 1), + presence_penalty=d.pop("presence_penalty", 0.0), + frequency_penalty=d.pop("frequency_penalty", 0.0), + temperature=d.pop("temperature", 1.0), + top_p=d.pop("top_p", 1.0), + top_k=d.pop("top_k", -1), + use_beam_search=d.pop("use_beam_search", False), + stop_token_ids=set(d.pop("stop_token_ids", set())), + max_num_steps=d.pop("max_num_steps", 16), + num_logprobs=d.pop("num_logprobs", 0), ) + if d: + raise ValueError(f"Unrecognized keys in dict: {d.keys()}") + return sampling_params diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 62d3ef30..406ca728 100644 --- 
a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum): FINISHED = enum.auto() +class SequenceData: + + def __init__( + self, + prompt_token_ids: List[int], + ) -> None: + self.prompt_token_ids = prompt_token_ids + + self.output_token_ids: List[int] = [] + self.cumulative_logprobs = 0.0 + + def get_len(self) -> int: + return len(self.output_token_ids) + len(self.prompt_token_ids) + + def get_token_ids(self) -> List[int]: + return self.prompt_token_ids + self.output_token_ids + + def get_last_token_id(self) -> int: + if not self.output_token_ids: + return self.prompt_token_ids[-1] + return self.output_token_ids[-1] + + def __repr__(self) -> str: + return (f"SequenceData(" + f"prompt={self.prompt}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"output_token_ids={self.output_token_ids})") + + class Sequence: def __init__( self, seq_id: int, + prompt: str, prompt_token_ids: List[int], block_size: int, ) -> None: self.seq_id = seq_id + self.prompt = prompt self.block_size = block_size - self.prompt_len = len(prompt_token_ids) + self.data = SequenceData(prompt_token_ids) + self.output_logprobs: List[Dict[int, float]] = [] + self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. - self._append_tokens(prompt_token_ids) - + self._append_tokens_to_blocks(prompt_token_ids) self.status = SequenceStatus.WAITING - # Used for beam search. - self.output_logprobs: List[Dict[int, float]] = [] - self.cumulative_logprobs = 0.0 def _append_logical_block(self) -> None: block = LogicalTokenBlock( @@ -41,7 +70,7 @@ class Sequence: ) self.logical_token_blocks.append(block) - def _append_tokens(self, token_ids: List[int]) -> None: + def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: while token_ids: if not self.logical_token_blocks: self._append_logical_block() @@ -57,26 +86,24 @@ class Sequence: def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None: assert token_id in logprobs - self._append_tokens([token_id]) + self._append_tokens_to_blocks([token_id]) self.output_logprobs.append(logprobs) - self.cumulative_logprobs += logprobs[token_id] + self.data.output_token_ids.append(token_id) + self.data.cumulative_logprobs += logprobs[token_id] def get_len(self) -> int: - return sum(block.num_tokens for block in self.logical_token_blocks) + return self.data.get_len() def get_token_ids(self) -> List[int]: - token_ids: List[int] = [] - for block in self.logical_token_blocks: - token_ids.extend(block.get_token_ids()) - return token_ids + return self.data.get_token_ids() def get_last_token_id(self) -> int: - return self.logical_token_blocks[-1].get_last_token_id() + return self.data.get_last_token_id() def fork(self, child_seq: 'Sequence') -> 'Sequence': child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks) child_seq.output_logprobs = copy.deepcopy(self.output_logprobs) - child_seq.cumulative_logprobs = self.cumulative_logprobs + child_seq.data = copy.deepcopy(self.data) def __repr__(self) -> str: return (f'Sequence(seq_id={self.seq_id}, ' @@ -128,17 +155,13 @@ class SequenceGroupMetadata: self, group_id: int, is_prompt: bool, - input_tokens: Dict[int, List[int]], # Seq id -> token ids. - context_len: int, - seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs. + seq_data: Dict[int, SequenceData], # Seq id -> sequence data. sampling_params: SamplingParams, - block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers. 
+ block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers. ) -> None: self.group_id = group_id self.is_prompt = is_prompt - self.input_tokens = input_tokens - self.context_len = context_len - self.seq_logprobs = seq_logprobs + self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 3298a3bc..90d5d7af 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Optional, Tuple import torch @@ -8,8 +8,8 @@ from cacheflow.model_executor.parallel_utils.parallel_state import ( initialize_all_reduce_launcher, get_tensor_model_parallel_world_size) from cacheflow.sampling_params import SamplingParams -from cacheflow.sequence import SequenceGroupMetadata -from cacheflow.sequence import SequenceOutputs +from cacheflow.sequence import (SequenceData, SequenceGroupMetadata, + SequenceOutputs) from cacheflow.worker.cache_engine import CacheEngine @@ -72,7 +72,6 @@ class Worker: self.cache_events = self.cache_engine.events self.gpu_cache = self.cache_engine.gpu_cache - def init_distributed_environment(self, distributed_init_method: str, rank: int, @@ -96,7 +95,6 @@ class Worker: seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]: seq_groups: List[Tuple[List[int], SamplingParams]] = [] - seq_logprobs: Dict[int, float] = {} input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -107,15 +105,15 @@ class Worker: if not seq_group_metadata.is_prompt: continue - seq_ids = list(seq_group_metadata.input_tokens.keys()) + seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) - seq_logprobs.update(seq_group_metadata.seq_logprobs) # Use any sequence in the group. 
seq_id = seq_ids[0] - prompt_tokens = seq_group_metadata.input_tokens[seq_id] + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) @@ -141,27 +139,26 @@ class Worker: if seq_group_metadata.is_prompt: continue - seq_ids = list(seq_group_metadata.input_tokens.keys()) + seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) - seq_logprobs.update(seq_group_metadata.seq_logprobs) for seq_id in seq_ids: - assert len(seq_group_metadata.input_tokens[seq_id]) == 1 - generation_token = seq_group_metadata.input_tokens[seq_id][0] + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() input_tokens.append(generation_token) - position = seq_group_metadata.context_len - 1 + context_len = seq_data.get_len() + position = context_len - 1 input_positions.append(position) block_table = seq_group_metadata.block_tables[seq_id] generation_block_tables.append(block_table) - max_context_len = max( - max_context_len, seq_group_metadata.context_len) + max_context_len = max(max_context_len, context_len) max_num_blocks_per_seq = max( max_num_blocks_per_seq, len(block_table)) - context_lens.append(seq_group_metadata.context_len) + context_lens.append(context_len) block_number = block_table[position // self.block_size] block_offset = position % self.block_size @@ -188,9 +185,13 @@ class Worker: block_tables_tensor = torch.tensor( padded_block_tables, dtype=torch.int, device='cuda') + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + input_metadata = InputMetadata( seq_groups=seq_groups, - seq_logprobs=seq_logprobs, + seq_data=seq_data, prompt_lens=prompt_lens, slot_mapping=slot_mapping_tensor, context_lens=context_lens_tensor, diff --git a/simple_server.py b/simple_server.py index 9644731c..c8cea42d 100644 --- a/simple_server.py +++ b/simple_server.py @@ -11,8 +11,8 @@ def main(args: argparse.Namespace): # Test the following inputs. test_inputs = [ ("A robot may not injure a human being", {}), # Use default parameters. - ("To be or not to be,", {"temperature": 0.8, "top_k": 5}), - ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}), + ("To be or not to be,", {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}), + ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}), ("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}), ] while True:
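
Notes on the change, with small illustrative sketches. These are not part of the diff; names and values below are made up for illustration.

The heart of the patch is _apply_penalties in sampler.py: for every sequence that has a non-zero presence or frequency penalty and at least one generated token, it bin-counts the tokens generated so far and subtracts frequency_penalty * count plus presence_penalty * (count > 0) from the corresponding logits, following the OpenAI API parameter definitions. A minimal single-sequence sketch of that arithmetic, assuming a [vocab_size] logits tensor:

import torch

def apply_penalties_single(logits, output_token_ids, presence_penalty, frequency_penalty):
    # logits: [vocab_size] for one sequence; output_token_ids: tokens generated so far.
    vocab_size = logits.shape[0]
    # Count how many times each vocab id has already been generated.
    counts = torch.bincount(
        torch.tensor(output_token_ids, dtype=torch.long),
        minlength=vocab_size).to(dtype=logits.dtype, device=logits.device)
    # Frequency penalty scales with the repeat count; presence penalty is a
    # flat offset for any token that has appeared at least once.
    logits = logits - frequency_penalty * counts
    logits = logits - presence_penalty * (counts > 0).to(logits.dtype)
    return logits

Positive penalties (up to 2.0) push the model away from tokens it has already produced; negative values (down to -2.0) pull it toward them, which is why SamplingParams validates both in [-2, 2]. Prompt tokens are deliberately excluded: _get_output_tokens collects only output_token_ids, so a prompt-only sequence (the usual prefill case) receives no penalty.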
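
The sampler also now builds the top-p/top-k tensors inside _apply_top_p_top_k rather than at the call site. The diff only shows the changed signature and the initial sort; the sketch below fills in the standard sorted-cumulative-sum truncation for context, and handles the top_k == -1 sentinel ("no limit") inline as an assumption, not necessarily exactly as the surrounding code does:

import torch

def top_p_top_k_sketch(probs, top_ps, top_ks):
    # probs: [num_seqs, vocab_size]; one (top_p, top_k) pair per sequence.
    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    # Top-p: zero out a token once the cumulative mass before it exceeds p.
    probs_sum = probs_sort.cumsum(dim=-1)
    probs_sort[(probs_sum - probs_sort) > p.unsqueeze(dim=1)] = 0.0
    # Top-k: keep only the k most likely tokens (k == -1 means keep all).
    ranks = torch.arange(probs.shape[-1], device=probs.device).unsqueeze(dim=0)
    probs_sort[(ranks >= k.unsqueeze(dim=1)) & (k.unsqueeze(dim=1) != -1)] = 0.0
    # Renormalize and restore the original vocab order.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    return torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

Moving the tensor construction into the helper keeps the forward pass readable now that there are more per-sequence sampling tensors (penalties, temperatures, top-p, top-k) to assemble.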
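
On the API side, SamplingParams.from_dict now pops each recognized key and rejects anything left over, so an unknown or misspelled field fails loudly instead of being silently ignored. A hypothetical request dict (the field values are arbitrary):

from cacheflow.sampling_params import SamplingParams

params = SamplingParams.from_dict({
    "temperature": 0.8,
    "top_p": 0.95,
    "presence_penalty": 0.5,
    "frequency_penalty": 0.5,
})  # ok: unspecified fields fall back to their defaults

SamplingParams.from_dict({"presence_penality": 0.5})
# raises ValueError: Unrecognized keys in dict: dict_keys(['presence_penality'])

Note that from_dict consumes the dict it is given (it uses pop), so callers that want to keep the original request around should pass in a copy.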
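
Finally, the Sequence refactor: token bookkeeping moves out of the logical token blocks into a SequenceData object that keeps prompt_token_ids and output_token_ids separate (plus the cumulative logprob), and SequenceGroupMetadata now ships that object to the workers and the sampler instead of raw token lists and a context length. Keeping the two lists separate is what lets the sampler penalize only generated tokens. A tiny illustration of the bookkeeping, with made-up token ids:

from cacheflow.sequence import SequenceData

data = SequenceData(prompt_token_ids=[101, 7592, 2088])
data.get_last_token_id()   # 2088: falls back to the last prompt token

data.output_token_ids.append(2023)
data.cumulative_logprobs += -0.7

data.get_len()             # 4: prompt tokens + generated tokens
data.get_token_ids()       # [101, 7592, 2088, 2023]
data.get_last_token_id()   # 2023

In normal use the appends go through Sequence.append_token, which updates the logical token blocks, the per-step logprobs, and data together; the worker then derives prompt tokens, generation positions, and context lengths from the same SequenceData rather than from separately passed fields.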