From 8d66a7b6d747dd28fb29b998473f159ec08be2da Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 10 May 2023 00:58:31 -0700
Subject: [PATCH] Rename variables and methods (#91)

---
 cacheflow/block.py              |  2 +-
 cacheflow/core/block_manager.py |  8 ++---
 cacheflow/core/scheduler.py     | 24 +++++++-------
 cacheflow/sampling_params.py    | 12 ++-----
 cacheflow/sequence.py           | 25 +++++++-------
 cacheflow/worker/controller.py  | 18 ++--------
 cacheflow/worker/worker.py      | 58 ++++++++++++++++-----------------
 7 files changed, 64 insertions(+), 83 deletions(-)

diff --git a/cacheflow/block.py b/cacheflow/block.py
index df8d46ab..01edf617 100644
--- a/cacheflow/block.py
+++ b/cacheflow/block.py
@@ -27,7 +27,7 @@ class LogicalTokenBlock:
     def is_full(self) -> bool:
         return self.num_tokens == self.block_size
 
-    def append(self, token_ids: List[int]) -> None:
+    def append_tokens(self, token_ids: List[int]) -> None:
         assert len(token_ids) <= self.get_num_empty_slots()
         self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
         self.num_tokens += len(token_ids)
diff --git a/cacheflow/core/block_manager.py b/cacheflow/core/block_manager.py
index 0b188508..9e64e1d6 100644
--- a/cacheflow/core/block_manager.py
+++ b/cacheflow/core/block_manager.py
@@ -97,15 +97,15 @@ class BlockSpaceManager:
         for seq in seq_group.seqs:
             self.block_tables[seq.seq_id] = block_table.copy()
 
-    def can_append(self, seq_group: SequenceGroup) -> bool:
+    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
         # Simple heuristic: If there is at least one free block
         # for each sequence, we can append.
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
         return num_seqs <= num_free_gpu_blocks
 
-    def append(self, seq: Sequence) -> Optional[Tuple[int, int]]:
-        """Allocate a physical slot for the new token."""
+    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
+        """Allocate a physical slot for a new token."""
         logical_blocks = seq.logical_token_blocks
         block_table = self.block_tables[seq.seq_id]
 
@@ -156,7 +156,7 @@
         num_free_blocks = self.gpu_allocator.get_num_free_blocks()
         # NOTE: Conservatively, we assume that every sequence will allocate
         # at least one free block right after the swap-in.
-        # NOTE: This should match the logic in can_append().
+        # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
         return num_free_blocks - num_required_blocks >= self.watermark_blocks
 
diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py
index 17c5c6d6..ba115313 100644
--- a/cacheflow/core/scheduler.py
+++ b/cacheflow/core/scheduler.py
@@ -9,7 +9,7 @@ from cacheflow.core.policy import PolicyFactory
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus
 
@@ -105,7 +105,7 @@ class Scheduler:
         preempted: List[SequenceGroup] = []
         while self.running:
             seq_group = self.running.pop(0)
-            while not self.block_manager.can_append(seq_group):
+            while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
                     victim_seq_group = self.running.pop(-1)
@@ -119,7 +119,7 @@
                     break
             else:
                 # Append new slots to the sequence group.
-                self._append_slot(seq_group, blocks_to_copy)
+                self._append_slot(seq_group, blocks_to_copy)
                 running.append(seq_group)
         self.running = running
 
@@ -143,7 +143,7 @@
             seq_group = self.swapped.pop(0)
             self._swap_in(seq_group, blocks_to_swap_in)
-            self._append(seq_group, blocks_to_copy)
+            self._append_slot(seq_group, blocks_to_copy)
             self.running.append(seq_group)
 
         num_batched_tokens = sum(
@@ -252,7 +252,7 @@
         prompt_group_ids = scheduler_output[3]
 
         # Create input data structures.
-        input_seq_groups: List[SequenceGroupInputs] = []
+        seq_group_metadata_list: List[SequenceGroupMetadata] = []
         updated_seq_groups: List[SequenceGroup] = self.running.copy()
 
         for seq_group in self.running:
@@ -274,7 +274,7 @@
                 # sequence length
                 seq_len = seq.get_len()
 
-            input_seq_group = SequenceGroupInputs(
+            seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
                 input_tokens=input_tokens,
@@ -283,14 +283,14 @@
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
-            input_seq_groups.append(input_seq_group)
+            seq_group_metadata_list.append(seq_group_metadata)
 
         # Execute the first stage of the pipeline.
-        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
+        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
             # Swap in and swap out should never happen at the same time.
             assert not (blocks_to_swap_in and blocks_to_swap_out)
             self.controllers[0].execute_stage(
-                input_seq_groups,
+                seq_group_metadata_list,
                 blocks_to_swap_in=blocks_to_swap_in,
                 blocks_to_swap_out=blocks_to_swap_out,
                 blocks_to_copy=blocks_to_copy,
@@ -330,7 +330,7 @@
 
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
-                seq.append(output.output_token, output.logprobs)
+                seq.append_token(output.output_token, output.logprobs)
 
                 # Check if the sequence has generated a stop token.
                 if output.output_token in stop_token_ids:
@@ -360,13 +360,13 @@
             if seq_group.group_id not in self.num_steps:
                 self.num_steps[seq_group.group_id] = 0
 
-    def _append(
+    def _append_slot(
         self,
         seq_group: SequenceGroup,
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            ret = self.block_manager.append(seq)
+            ret = self.block_manager.append_slot(seq)
             if ret is not None:
                 src_block, dst_block = ret
                 if src_block in blocks_to_copy:
diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py
index 4daeaa48..8e64126a 100644
--- a/cacheflow/sampling_params.py
+++ b/cacheflow/sampling_params.py
@@ -1,4 +1,4 @@
-from typing import Optional, Set, Dict
+from typing import Dict, Set
 
 
 class SamplingParams:
@@ -12,7 +12,6 @@
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
-        context_window_size: Optional[int],
     ) -> None:
         if n < 1:
             raise ValueError(f'n must be at least 1, got {n}.')
@@ -27,10 +26,6 @@
         if num_logprobs < 0:
             raise ValueError(
                 f'num_logprobs must be non-negative, got {num_logprobs}.')
-        if context_window_size is not None and context_window_size < 0:
-            raise ValueError(
-                'context_window_size must be non-negative, '
-                f'got {context_window_size}.')
 
         if use_beam_search:
             if n == 1:
@@ -58,7 +53,6 @@
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
-        self.context_window_size = context_window_size
 
     def __repr__(self) -> str:
         return (f'SamplingParams(n={self.n}, '
@@ -67,8 +61,7 @@
                 f'use_beam_search={self.use_beam_search}, '
                 f'stop_token_ids={self.stop_token_ids}, '
                 f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs}, '
-                f'context_window_size={self.context_window_size})')
+                f'num_logprobs={self.num_logprobs}')
 
     @classmethod
     def from_dict(cls, d: Dict) -> 'SamplingParams':
@@ -80,5 +73,4 @@
             stop_token_ids=set(d.get('stop_token_ids', set())),
             max_num_steps=d.get('max_num_steps', 16),
             num_logprobs=d.get('num_logprobs', 0),
-            context_window_size=d.get('context_window_size', None),
         )
diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py
index 6f5501a9..62d3ef30 100644
--- a/cacheflow/sequence.py
+++ b/cacheflow/sequence.py
@@ -18,45 +18,46 @@ class Sequence:
     def __init__(
         self,
         seq_id: int,
-        token_ids: List[int],
+        prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
         self.block_size = block_size
+        self.prompt_len = len(prompt_token_ids)
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the given token ids.
-        self.add(token_ids)
+        # Initialize the logical token blocks with the prompt token ids.
+        self._append_tokens(prompt_token_ids)
 
-        self.prompt_len = len(token_ids)
         self.status = SequenceStatus.WAITING
 
+        # Used for beam search.
         self.output_logprobs: List[Dict[int, float]] = []
         self.cumulative_logprobs = 0.0
 
-    def add_block(self) -> None:
+    def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
             block_size=self.block_size,
         )
         self.logical_token_blocks.append(block)
 
-    def add(self, token_ids: List[int]) -> None:
+    def _append_tokens(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
-                self.add_block()
+                self._append_logical_block()
             last_block = self.logical_token_blocks[-1]
             if last_block.is_full():
-                self.add_block()
+                self._append_logical_block()
                 last_block = self.logical_token_blocks[-1]
 
             num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append(token_ids[:num_empty_slots])
+            last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]
 
-    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self.add([token_id])
+        self._append_tokens([token_id])
         self.output_logprobs.append(logprobs)
         self.cumulative_logprobs += logprobs[token_id]
 
@@ -121,7 +122,7 @@
                 f'num_seqs={len(self.seqs)})')
 
 
-class SequenceGroupInputs:
+class SequenceGroupMetadata:
 
     def __init__(
         self,
diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py
index 46224727..018259b8 100644
--- a/cacheflow/worker/controller.py
+++ b/cacheflow/worker/controller.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Union, Tuple, Optional
+from typing import List, Optional, Tuple, Union
 
 try:
     import ray
@@ -6,7 +6,6 @@ except ImportError:
     ray = None
 
 from cacheflow.core.scheduler import Scheduler
-from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker
 
 
@@ -81,23 +80,12 @@ class Controller:
         self.next_node = next_node
         self.is_last_stage = isinstance(next_node, Scheduler)
 
-    def execute_stage(
-        self,
-        input_seq_groups: List[SequenceGroupInputs],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> None:
+    def execute_stage(self, *args, **kwargs) -> None:
         all_outputs = []
         for worker in self.workers:
             executor = (worker.execute_stage.remote if self.use_ray
                         else worker.execute_stage)
-            output = executor(
-                input_seq_groups,
-                blocks_to_swap_in,
-                blocks_to_swap_out,
-                blocks_to_copy,
-            )
+            output = executor(*args, **kwargs)
             all_outputs.append(output)
 
         if self.use_ray:
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index 9f81f207..dc864934 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -8,10 +8,11 @@
 from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher, get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
 
+
 class Worker:
 
     def __init__(
@@ -93,30 +94,29 @@
 
     def prepare_inputs(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         seq_logprobs: Dict[int, float] = {}
-        sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
-        for input_seq_group in input_seq_groups:
-            if not input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if not seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             # Use any sequence in the group.
             seq_id = seq_ids[0]
-            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
@@ -126,7 +126,7 @@
             input_positions.extend(range(len(prompt_tokens)))
 
             # Compute the slot mapping.
-            block_table = input_seq_group.block_tables[seq_id]
+            block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
@@ -138,31 +138,31 @@
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        for input_seq_group in input_seq_groups:
-            if input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             for seq_id in seq_ids:
-                assert len(input_seq_group.input_tokens[seq_id]) == 1
-                generation_token = input_seq_group.input_tokens[seq_id][0]
+                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
+                generation_token = seq_group_metadata.input_tokens[seq_id][0]
                 input_tokens.append(generation_token)
 
-                position = input_seq_group.context_len - 1
+                position = seq_group_metadata.context_len - 1
                 input_positions.append(position)
 
-                block_table = input_seq_group.block_tables[seq_id]
+                block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)
 
                 max_context_len = max(
-                    max_context_len, input_seq_group.context_len)
+                    max_context_len, seq_group_metadata.context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(input_seq_group.context_len)
+                context_lens.append(seq_group_metadata.context_len)
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
@@ -203,30 +203,30 @@
     @torch.inference_mode()
     def execute_stage(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
-        command_issued = False
+        issued_cache_op = False
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            command_issued = True
+            issued_cache_op = True
 
-        if command_issued:
+        if issued_cache_op:
             cache_events = self.cache_events
         else:
             cache_events = None
 
         # If there is no input, we don't need to execute the model.
-        if not input_seq_groups:
+        if not seq_group_metadata_list:
             if cache_events is not None:
                 for event in cache_events:
                     event.wait()
@@ -234,7 +234,7 @@
 
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            input_seq_groups)
+            seq_group_metadata_list)
 
         # Execute the model.
         output = self.model(