diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py
index ba115313..7656ae97 100644
--- a/cacheflow/core/scheduler.py
+++ b/cacheflow/core/scheduler.py
@@ -5,6 +5,7 @@ import time
 from typing import Any, Dict, List, Optional, Tuple
 
 from cacheflow.core.block_manager import BlockSpaceManager
+from cacheflow.logger import init_logger
 from cacheflow.core.policy import PolicyFactory
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
@@ -14,6 +15,10 @@ from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus
 
 
+logger = init_logger(__name__)
+_LOGGING_INTERVAL_SEC = 10
+
+
 class PreemptionMode(enum.Enum):
     """Preemption modes.
 
@@ -37,8 +42,7 @@ class Scheduler:
         num_cpu_blocks: int,
         max_num_batched_tokens: int,
         max_num_sequences: int,
-        collect_stats: bool,
-        do_memory_analysis: bool = False,
+        log_stats: bool,
     ) -> None:
         self.controllers = controllers
         self.block_size = block_size
@@ -46,8 +50,7 @@ class Scheduler:
         self.num_cpu_blocks = num_cpu_blocks
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_sequences = max_num_sequences
-        self.collect_stats = collect_stats
-        self.do_memory_analysis = do_memory_analysis
+        self.log_stats = log_stats
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name='fcfs')
@@ -69,8 +72,9 @@ class Scheduler:
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []
 
-        # Performance-related statistics.
-        self.stats = Stats(num_gpu_blocks, num_cpu_blocks)
+        self.last_logging_time: float = 0.0
+        # List[timestamp, num_tokens]
+        self.num_input_tokens: List[Tuple[float, int]] = []
 
     def add_sequence_groups(
         self,
@@ -186,59 +190,46 @@ class Scheduler:
             num_batched_tokens += num_prompt_tokens
             prompt_group_ids.append(seq_group.group_id)
 
-        if self.collect_stats:
-            if self.running or blocks_to_swap_in or blocks_to_swap_out:
-                self.stats.timestamps.append(now - self.stats.start_time)
-                self.stats.input_lens.append(num_batched_tokens)
-                self.stats.swap_out_lens.append(len(blocks_to_swap_out) * self.block_size)
-                self.stats.swap_in_lens.append(len(blocks_to_swap_in) * self.block_size)
-                self.stats.num_preemption.append(len(preempted))
-                self.stats.num_swapped.append(len(self.swapped))
-                self.stats.num_running.append(len(self.running))
-                self.stats.num_waiting.append(len(self.waiting))
+        if not self.log_stats:
+            return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
+                    prompt_group_ids)
 
-                num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-                num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
-                self.stats.gpu_cache_usage.append(num_used_gpu_blocks / self.num_gpu_blocks)
+        now = time.time()
+        if num_batched_tokens > 0:
+            self.num_input_tokens.append((now, num_batched_tokens))
+        elapsed_time = now - self.last_logging_time
+        if elapsed_time > _LOGGING_INTERVAL_SEC:
+            self.last_logging_time = now
+            self.num_input_tokens = [
+                (t, n) for t, n in self.num_input_tokens
+                if now - t < _LOGGING_INTERVAL_SEC
+            ]
+            if len(self.num_input_tokens) > 1:
+                total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
+                window = now - self.num_input_tokens[0][0]
+                avg_throughput = total_num_tokens / window
+            else:
+                avg_throughput = 0.0
+
+            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
+            num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
+            gpu_cache_usage = num_used_gpu_blocks / self.num_gpu_blocks
+            if self.num_cpu_blocks > 0:
                 num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
                 num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
-                self.stats.cpu_cache_usage.append(num_used_cpu_blocks / self.num_cpu_blocks)
+                cpu_cache_usage = num_used_cpu_blocks / self.num_cpu_blocks
+            else:
+                cpu_cache_usage = 0.0
 
-                if self.do_memory_analysis:
-                    block_tables = self.block_manager.block_tables
-                    num_logical_blocks = 0
-                    num_logical_tokens = 0
-                    num_physical_blocks = 0
-                    num_physical_tokens = 0
-                    physical_block_numbers = set()
-                    num_reserved_tokens = 0
-                    for seq_group in self.running:
-                        group_id = seq_group.group_id
-                        sampling_params = self.sampling_params[group_id]
-                        max_num_steps = sampling_params.max_num_steps
-                        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                            num_logical_blocks += len(seq.logical_token_blocks)
-                            num_logical_tokens += seq.get_len()
+            logger.info(
+                f"Throughput: {avg_throughput:.1f} tokens/s, "
+                f"Running: {len(self.running)} reqs, "
+                f"Swapped: {len(self.swapped)} reqs, "
+                f"Pending: {len(self.waiting)} reqs, "
+                f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
+                f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
 
-                            seq_id = seq.seq_id
-                            block_table = block_tables[seq_id]
-                            for i, block in enumerate(block_table):
-                                if block.block_number in physical_block_numbers:
-                                    continue
-                                physical_block_numbers.add(block.block_number)
-                                num_physical_blocks += 1
-                                num_physical_tokens += seq.logical_token_blocks[i].num_tokens
-
-                    assert num_physical_blocks == num_used_gpu_blocks
-                    self.stats.num_logical_blocks.append(num_logical_blocks)
-                    self.stats.num_logical_tokens.append(num_logical_tokens)
-                    self.stats.num_physical_blocks.append(num_physical_blocks)
-                    self.stats.num_physical_tokens.append(num_physical_tokens)
-                    self.stats.num_reserved_tokens.append(num_reserved_tokens)
-
-        return (blocks_to_swap_in,
-                blocks_to_swap_out,
-                blocks_to_copy,
+        return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
                 prompt_group_ids)
 
     def step(self) -> List[SequenceGroup]:
@@ -455,75 +446,3 @@ class Scheduler:
         blocks_to_swap_out.update(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED
-
-    def reset_stats(self) -> None:
-        self.stats.reset(self.num_gpu_blocks, self.num_cpu_blocks)
-
-    def save_stats(
-        self,
-        output_dir: str,
-    ) -> None:
-        assert self.collect_stats, 'Statistics collection is disabled.'
-        self.stats.save(output_dir)
-
-
-class Stats:
-
-    def __init__(
-        self,
-        num_gpu_blocks: int,
-        num_cpu_blocks: int,
-    ) -> None:
-        self.start_time: float = time.time()
-        self.num_gpu_blocks = num_gpu_blocks
-        self.num_cpu_blocks = num_cpu_blocks
-
-        self.timestamps: List[float] = []
-        self.input_lens: List[int] = []
-        self.swap_out_lens: List[int] = []
-        self.swap_in_lens: List[int] = []
-        self.num_preemption: List[int] = []
-        self.num_waiting: List[int] = []
-        self.num_running: List[int] = []
-        self.num_swapped: List[int] = []
-        self.gpu_cache_usage: List[float] = []
-        self.cpu_cache_usage: List[float] = []
-
-        self.num_logical_blocks: List[int] = []
-        self.num_logical_tokens: List[int] = []
-        self.num_physical_blocks: List[int] = []
-        self.num_physical_tokens: List[int] = []
-        self.num_reserved_tokens: List[int] = []
-
-    def reset(
-        self,
-        num_gpu_blocks: int,
-        num_cpu_blocks: int,
-    ) -> None:
-        self.__init__(num_gpu_blocks, num_cpu_blocks)
-
-    def to_dict(self) -> Dict[str, Any]:
-        return {
-            'start_time': self.start_time,
-            'num_gpu_blocks': self.num_gpu_blocks,
-            'num_cpu_blocks': self.num_cpu_blocks,
-            'timestamps': self.timestamps,
-            'input_lens': self.input_lens,
-            'swap_out_lens': self.swap_out_lens,
-            'swap_in_lens': self.swap_in_lens,
-            'num_preemption': self.num_preemption,
-            'num_waiting': self.num_waiting,
-            'num_running': self.num_running,
-            'num_swapped': self.num_swapped,
-            'gpu_cache_usage': self.gpu_cache_usage,
-            'cpu_cache_usage': self.cpu_cache_usage,
-            'num_logical_blocks': self.num_logical_blocks,
-            'num_logical_tokens': self.num_logical_tokens,
-            'num_physical_blocks': self.num_physical_blocks,
-            'num_physical_tokens': self.num_physical_tokens,
-            'num_reserved_tokens': self.num_reserved_tokens,
-        }
-
-    def save(self, output_dir: str) -> None:
-        with open(os.path.join(output_dir, 'stats.pkl'), 'wb') as f:
-            pickle.dump(self.to_dict(), f)
diff --git a/cacheflow/core/server.py b/cacheflow/core/server.py
index 9eb96efd..2f968f8b 100644
--- a/cacheflow/core/server.py
+++ b/cacheflow/core/server.py
@@ -44,18 +44,16 @@ class Server:
         gpu_memory: int,
         cpu_memory: int,
         use_ray: bool,
-        collect_stats: bool = False,
-        do_memory_analysis: bool = False,
+        log_stats: bool,
     ):
         logger.info(
             "Initializing a server with config: "
             f"model={model!r}, "
             f"dtype={dtype}, "
             f"use_dummy_weights={use_dummy_weights}, "
-            f"cache_dir={cache_dir}, "
+            f"cache_dir={cache_dir!r}, "
             f"use_np_cache={use_np_cache}, "
             f"tensor_parallel_size={tensor_parallel_size}, "
-            f"block_size={block_size}, "
             f"seed={seed})"
         )
         self.num_nodes = num_nodes
@@ -111,8 +109,7 @@ class Server:
             num_cpu_blocks=self.num_cpu_blocks,
             max_num_batched_tokens=max_num_batched_tokens,
             max_num_sequences=max_num_sequences,
-            collect_stats=collect_stats,
-            do_memory_analysis=do_memory_analysis,
+            log_stats=log_stats,
         )
         # Connect the controllers.
         for i in range(len(self.controllers) - 1):
@@ -244,6 +241,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
     parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
     parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
+    parser.add_argument('--log-stats', action='store_true', help='log system statistics')
     return parser
 
 
@@ -286,6 +284,7 @@ def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
         use_ray=args.use_ray,
+        log_stats=args.log_stats,
     )
 
     # Create a frontend.
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index dc864934..3298a3bc 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -91,7 +91,6 @@ class Worker:
         initialize_model_parallel(tensor_parallel_size,
                                   pipeline_parallel_size)
 
-
     def prepare_inputs(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
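Note: after this change, the periodic log line is emitted only when the server is started with the new `--log-stats` flag, and its throughput figure is a sliding-window estimate over the last `_LOGGING_INTERVAL_SEC` seconds of `(timestamp, num_tokens)` samples. For reference, a minimal standalone sketch of that window logic; the `ThroughputMeter` class and its method names are hypothetical, introduced only for illustration, and only the arithmetic mirrors the diff:

    import time
    from typing import List, Tuple

    _LOGGING_INTERVAL_SEC = 10

    class ThroughputMeter:
        """Sliding-window tokens/s estimate mirroring the scheduler diff."""

        def __init__(self) -> None:
            # (timestamp, num_tokens) samples, oldest first.
            self.num_input_tokens: List[Tuple[float, int]] = []

        def record(self, num_tokens: int) -> None:
            if num_tokens > 0:
                self.num_input_tokens.append((time.time(), num_tokens))

        def avg_throughput(self) -> float:
            now = time.time()
            # Keep only samples newer than the logging interval.
            self.num_input_tokens = [
                (t, n) for t, n in self.num_input_tokens
                if now - t < _LOGGING_INTERVAL_SEC
            ]
            if len(self.num_input_tokens) <= 1:
                return 0.0
            # As in the diff, the newest sample is excluded from the
            # numerator: no time has elapsed to amortize its tokens over.
            total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
            window = now - self.num_input_tokens[0][0]
            return total_num_tokens / window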