From f756799b84f5558c82c7a049069f845b31573e9e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 19 May 2023 11:35:44 -0600 Subject: [PATCH] Use runtime profiling to replace manual memory analyzers (#81) --- cacheflow/core/server.py | 52 +-- cacheflow/frontend/fastapi_frontend.py | 10 +- cacheflow/model_executor/__init__.py | 7 +- cacheflow/model_executor/layers/attention.py | 50 +-- cacheflow/model_executor/layers/sampler.py | 2 +- cacheflow/model_executor/memory_analyzer.py | 370 ------------------- cacheflow/model_executor/model_loader.py | 37 -- cacheflow/model_executor/models/gpt2.py | 3 +- cacheflow/model_executor/models/gpt_neox.py | 3 +- cacheflow/model_executor/models/llama.py | 3 +- cacheflow/model_executor/models/opt.py | 3 +- cacheflow/model_executor/utils.py | 12 + cacheflow/worker/controller.py | 45 ++- cacheflow/worker/worker.py | 92 ++++- 14 files changed, 211 insertions(+), 478 deletions(-) delete mode 100644 cacheflow/model_executor/memory_analyzer.py diff --git a/cacheflow/core/server.py b/cacheflow/core/server.py index 46b75d81..65144b2e 100644 --- a/cacheflow/core/server.py +++ b/cacheflow/core/server.py @@ -6,15 +6,14 @@ try: import ray except ImportError: ray = None +import numpy as np import torch from cacheflow.core.scheduler import Scheduler from cacheflow.frontend.simple_frontend import SimpleFrontend from cacheflow.logger import init_logger -from cacheflow.model_executor import get_memory_analyzer from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import SequenceGroup -from cacheflow.utils import get_gpu_memory, get_cpu_memory from cacheflow.worker.controller import Controller, DeviceID logger = init_logger(__name__) @@ -34,14 +33,13 @@ class Server: dtype: str, seed: int, swap_space: int, + gpu_memory_utilization: float, max_num_batched_tokens: int, max_num_sequences: int, num_nodes: int, num_devices_per_node: int, distributed_init_method: str, all_stage_devices: List[List[DeviceID]], - gpu_memory: int, - cpu_memory: int, use_ray: bool, log_stats: bool, ): @@ -63,19 +61,6 @@ class Server: assert self.world_size == 1, ( "Only support single GPU without Ray.") - self.memory_analyzer = get_memory_analyzer( - model_name=model, - block_size=block_size, - dtype=dtype, - gpu_memory=gpu_memory, - cpu_memory=cpu_memory, - tensor_parallel_size=tensor_parallel_size, - ) - self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks( - max_num_batched_tokens=max_num_batched_tokens) - self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks( - swap_space_gib=swap_space) - # Create a controller for each pipeline stage. self.controllers: List[Controller] = [] for i in range(pipeline_parallel_size): @@ -87,19 +72,35 @@ class Server: tensor_parallel_size=tensor_parallel_size, distributed_init_method=distributed_init_method, model_name=model, - block_size=block_size, - num_gpu_blocks=self.num_gpu_blocks, - num_cpu_blocks=self.num_cpu_blocks, dtype=dtype, seed=seed, cache_dir=cache_dir, use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache, max_num_batched_tokens=max_num_batched_tokens, + max_num_sequences=max_num_sequences, use_ray=use_ray, ) self.controllers.append(controller) + # Initialize cache engine. 
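The added code that follows replaces the analyzer-derived block counts with numbers reported by the workers themselves: each controller asks its workers how many GPU and CPU cache blocks fit after a profiling pass, and the server keeps the element-wise minimum so the shared block tables are valid on every worker. A minimal standalone sketch of that reduction, using made-up per-worker results and a hypothetical helper name rather than the real Controller API (the patch itself uses numpy's np.min; the built-in min is used here to keep the sketch dependency-free):

# Sketch only: the (gpu_blocks, cpu_blocks) pairs are hypothetical profiling
# results, and reduce_block_counts is an illustrative helper, not part of the patch.
from typing import List, Tuple

def reduce_block_counts(per_worker: List[Tuple[int, int]]) -> Tuple[int, int]:
    # Take the smallest counts so every worker can hold the same cache layout.
    num_gpu_blocks = min(gpu for gpu, _ in per_worker)
    num_cpu_blocks = min(cpu for _, cpu in per_worker)
    return num_gpu_blocks, num_cpu_blocks

assert reduce_block_counts([(7100, 2560), (6980, 2560)]) == (6980, 2560)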
+ all_worker_num_available_blocks = [] + for controller in self.controllers: + all_worker_num_available_blocks.extend( + controller.get_num_available_blocks( + block_size, swap_space, gpu_memory_utilization) + ) + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + self.num_gpu_blocks = np.min([b[0] for b in all_worker_num_available_blocks]) + self.num_cpu_blocks = np.min([b[1] for b in all_worker_num_available_blocks]) + logger.info(f'# GPU blocks: {self.num_gpu_blocks}, ' + f'# CPU blocks: {self.num_cpu_blocks}') + for controller in self.controllers: + controller.init_cache_engine(block_size, self.num_gpu_blocks, + self.num_cpu_blocks) + # Create a scheduler. self.scheduler = Scheduler( controllers=self.controllers, @@ -214,7 +215,11 @@ def initialize_cluster( all_stage_devices) +_GiB = 1 << 30 + + def add_server_arguments(parser: argparse.ArgumentParser): + """Shared arguments for CacheFlow servers.""" # Model arguments parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name') parser.add_argument('--cache-dir', type=str, default=None, @@ -238,6 +243,7 @@ def add_server_arguments(parser: argparse.ArgumentParser): # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', type=int, default=0, help='random seed') parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') + parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor') parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration') parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration') parser.add_argument('--log-stats', action='store_true', help='log system statistics') @@ -245,8 +251,11 @@ def add_server_arguments(parser: argparse.ArgumentParser): def process_server_arguments(args: argparse.Namespace): + """Post process the parsed arguments.""" if args.pipeline_parallel_size * args.tensor_parallel_size > 1: args.use_ray = True + args.swap_space = args.swap_space * _GiB + args.max_num_sequences = min(args.max_num_sequences, args.max_num_batched_tokens) return args @@ -274,14 +283,13 @@ def init_local_server_and_frontend_with_arguments(args: argparse.Namespace): dtype=args.dtype, seed=args.seed, swap_space=args.swap_space, + gpu_memory_utilization=args.gpu_memory_utilization, max_num_batched_tokens=args.max_num_batched_tokens, max_num_sequences=args.max_num_sequences, num_nodes=num_nodes, num_devices_per_node=num_devices_per_node, distributed_init_method=distributed_init_method, all_stage_devices=all_stage_devices, - gpu_memory=get_gpu_memory(), - cpu_memory=get_cpu_memory(), use_ray=args.use_ray, log_stats=args.log_stats, ) diff --git a/cacheflow/frontend/fastapi_frontend.py b/cacheflow/frontend/fastapi_frontend.py index 9dabd4db..cb7fcc2b 100644 --- a/cacheflow/frontend/fastapi_frontend.py +++ b/cacheflow/frontend/fastapi_frontend.py @@ -15,7 +15,7 @@ from cacheflow.core.server import (Server, add_server_arguments, from cacheflow.frontend.utils import get_tokenizer from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence, SequenceGroup -from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory +from cacheflow.utils import Counter from 
cacheflow.worker.controller import DeviceID TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds @@ -34,6 +34,7 @@ class FastAPIServer: dtype: str, seed: int, swap_space: int, + gpu_memory_utilization: float, max_num_batched_tokens: int, max_num_sequences: int, num_nodes: int, @@ -41,6 +42,7 @@ class FastAPIServer: distributed_init_method: str, all_stage_devices: List[List[DeviceID]], server_use_ray: bool, + log_stats: bool, ): self.block_size = block_size @@ -62,15 +64,15 @@ class FastAPIServer: dtype=dtype, seed=seed, swap_space=swap_space, + gpu_memory_utilization=gpu_memory_utilization, max_num_batched_tokens=max_num_batched_tokens, max_num_sequences=max_num_sequences, num_nodes=num_nodes, num_devices_per_node=num_devices_per_node, distributed_init_method=distributed_init_method, all_stage_devices=all_stage_devices, - gpu_memory=get_gpu_memory(), - cpu_memory=get_cpu_memory(), use_ray=server_use_ray, + log_stats=log_stats, ) self.running_seq_groups: Dict[int, SequenceGroup] = {} @@ -182,6 +184,7 @@ if __name__ == "__main__": dtype=args.dtype, seed=args.seed, swap_space=args.swap_space, + gpu_memory_utilization=args.gpu_memory_utilization, max_num_batched_tokens=args.max_num_batched_tokens, max_num_sequences=args.max_num_sequences, num_nodes=num_nodes, @@ -189,6 +192,7 @@ if __name__ == "__main__": distributed_init_method=distributed_init_method, all_stage_devices=all_stage_devices, server_use_ray=args.use_ray, + log_stats=args.log_stats, ) uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/cacheflow/model_executor/__init__.py b/cacheflow/model_executor/__init__.py index d839756a..acb84aca 100644 --- a/cacheflow/model_executor/__init__.py +++ b/cacheflow/model_executor/__init__.py @@ -1,11 +1,12 @@ from cacheflow.model_executor.input_metadata import InputMetadata -from cacheflow.model_executor.model_loader import get_model, get_memory_analyzer -from cacheflow.model_executor.utils import set_random_seed +from cacheflow.model_executor.model_loader import get_model +from cacheflow.model_executor.utils import (set_random_seed, + get_cache_block_size) __all__ = [ "InputMetadata", + "get_cache_block_size", "get_model", - "get_memory_analyzer", "set_random_seed", ] diff --git a/cacheflow/model_executor/layers/attention.py b/cacheflow/model_executor/layers/attention.py index 82c9b1cd..7c178848 100644 --- a/cacheflow/model_executor/layers/attention.py +++ b/cacheflow/model_executor/layers/attention.py @@ -11,6 +11,8 @@ from cacheflow import pos_encoding_ops from cacheflow.model_executor.input_metadata import InputMetadata +_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256] + class GPTCacheFlowAttention(nn.Module): """GPT-style multi-head attention. @@ -39,11 +41,19 @@ class GPTCacheFlowAttention(nn.Module): 5. Output a flattened 1D tensor. """ - def __init__(self, scale: float) -> None: + def __init__(self, num_heads: int, head_size: int, scale: float) -> None: super().__init__() + self.num_heads = num_heads + self.head_size = head_size self.scale = float(scale) self.attn_op = xops.fmha.cutlass.FwOp() + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError(f'head_size ({self.head_size}) is not supported by ' + 'the single_query_cached_kv_attention kernel. 
' + 'Use one of the following head sizes: ' + f'{_SUPPORTED_HEAD_SIZES}.') + def multi_query_kv_attention( self, output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] @@ -74,14 +84,6 @@ class GPTCacheFlowAttention(nn.Module): value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] input_metadata: InputMetadata, ) -> None: - head_size = value_cache.shape[2] - supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256] - if head_size not in supported_head_sizes: - raise ValueError(f'head_size ({head_size}) is not supported by ' - 'the single_query_cached_kv_attention kernel. ' - 'Use one of the following head sizes: ' - f'{supported_head_sizes}.') - block_size = value_cache.shape[3] attention_ops.single_query_cached_kv_attention( output, @@ -100,8 +102,8 @@ class GPTCacheFlowAttention(nn.Module): query: torch.Tensor, # [num_tokens, num_heads * head_size] key: torch.Tensor, # [num_tokens, num_heads * head_size] value: torch.Tensor, # [num_tokens, num_heads * head_size] - key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size] input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # [num_tokens, num_heads * head_size] @@ -109,11 +111,9 @@ class GPTCacheFlowAttention(nn.Module): # tensor of shape [num_tokens, 3 * num_heads * head_size]. # Reshape the query, key, and value tensors. - num_heads = value_cache.shape[1] - head_size = value_cache.shape[2] - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_heads, head_size) - value = value.view(-1, num_heads, head_size) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) # Pre-allocate the output tensor. output = torch.empty_like(query) @@ -134,8 +134,11 @@ class GPTCacheFlowAttention(nn.Module): cache_event.wait() # Reshape the keys and values and store them in the cache. + # When key_cache and value_cache are not provided, the new key + # and value vectors will not be cached. num_valid_tokens = input_metadata.num_valid_tokens - if num_valid_tokens > 0: + if (num_valid_tokens > 0 and key_cache is not None + and value_cache is not None): # The stride is 3 because the key and value are sliced from qkv. cache_ops.reshape_and_cache( key[:num_valid_tokens], @@ -146,6 +149,10 @@ class GPTCacheFlowAttention(nn.Module): ) if input_metadata.num_generation_tokens > 0: + assert key_cache is not None and value_cache is not None, ( + "key_cache and value_cache must be provided when " + "generating tokens." + ) # Compute the attention op for generation tokens. self.single_query_cached_kv_attention( output[num_prompt_tokens:num_valid_tokens], @@ -156,7 +163,7 @@ class GPTCacheFlowAttention(nn.Module): # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. 
- return output.view(-1, num_heads * head_size) + return output.view(-1, self.num_heads * self.head_size) class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention): @@ -164,12 +171,14 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention): def __init__( self, + num_heads: int, + head_size: int, scale: float, rotary_dim: int, max_position: int = 8192, base: int = 10000, ) -> None: - super().__init__(scale) + super().__init__(num_heads, head_size, scale) # Create the cos and sin cache. inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) @@ -199,12 +208,11 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention): ) -> torch.Tensor: # [num_tokens, num_heads * head_size] # Apply rotary embedding to the query and key before passing them # to the attention op. - head_size = value_cache.shape[2] pos_encoding_ops.rotary_embedding_neox( positions, query, key, - head_size, + self.head_size, self.cos_sin_cache, ) return super().forward( diff --git a/cacheflow/model_executor/layers/sampler.py b/cacheflow/model_executor/layers/sampler.py index 7321dbf3..84cccadb 100644 --- a/cacheflow/model_executor/layers/sampler.py +++ b/cacheflow/model_executor/layers/sampler.py @@ -74,7 +74,7 @@ class Sampler(nn.Module): # Apply top-p and top-k truncation. top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size) assert len(top_ps) == len(top_ks) == probs.shape[0] - if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks): + if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks): probs = _apply_top_p_top_k(probs, top_ps, top_ks) # Sample the next tokens. diff --git a/cacheflow/model_executor/memory_analyzer.py b/cacheflow/model_executor/memory_analyzer.py deleted file mode 100644 index fb910e64..00000000 --- a/cacheflow/model_executor/memory_analyzer.py +++ /dev/null @@ -1,370 +0,0 @@ -import torch -from transformers import AutoConfig - -from cacheflow.logger import init_logger -from cacheflow.model_executor.utils import get_dtype_size - -logger = init_logger(__name__) - -_GiB = 1 << 30 - - -class CacheFlowMemoryAnalyzer: - - def get_max_num_gpu_blocks( - self, - max_num_batched_tokens: int, - memory_utilization: float, - ) -> int: - raise NotImplementedError() - - def get_workspace_size(self) -> int: - return 1 * _GiB - - def get_cache_block_size(self) -> int: - raise NotImplementedError() - - def get_max_num_cpu_blocks( - self, - swap_space_gib: int, - ) -> int: - swap_space = swap_space_gib * _GiB - cpu_memory = self.cpu_memory - if swap_space > 0.8 * cpu_memory: - raise ValueError(f'The swap space ({swap_space_gib:.2f} GiB) ' - 'takes more than 80% of the available memory ' - f'({cpu_memory / _GiB:.2f} GiB).' - 'Please check the swap space size.') - if swap_space > 0.5 * cpu_memory: - logger.info(f'WARNING: The swap space ({swap_space_gib:.2f} GiB) ' - 'takes more than 50% of the available memory ' - f'({cpu_memory / _GiB:.2f} GiB).' 
- 'This may slow the system performance.') - max_num_blocks = swap_space // self.get_cache_block_size() - return max_num_blocks - - def get_param_size(self) -> int: - raise NotImplementedError() - - def get_max_act_size(self, max_num_batched_tokens: int) -> int: - raise NotImplementedError() - - def get_cache_block_size(self) -> int: - key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size - value_cache_block = key_cache_block - total = self.num_layers * (key_cache_block + value_cache_block) - dtype_size = get_dtype_size(self.dtype) - return dtype_size * total - - def get_max_num_gpu_blocks( - self, - max_num_batched_tokens: int, - memory_utilization: float = 0.95, - ) -> int: - # NOTE(woosuk): This assumes that the machine has homogeneous GPUs. - usable_memory = int(memory_utilization * self.gpu_memory) - - param_size = self.get_param_size() - act_size = self.get_max_act_size(max_num_batched_tokens) - workspace_size = self.get_workspace_size() - - max_cache_size = usable_memory - (param_size + act_size + workspace_size) - if max_cache_size <= 0: - raise RuntimeError('Not enough GPU memory.') - max_num_blocks = max_cache_size // self.get_cache_block_size() - return max_num_blocks - - -class GPT2MemoryAnalyzer(CacheFlowMemoryAnalyzer): - - def __init__( - self, - model_name: str, - block_size: int, - dtype: torch.dtype, - gpu_memory: int, - cpu_memory: int, - tensor_parallel_size: int, - ) -> None: - self.model_name = model_name - self.block_size = block_size - self.dtype = dtype - self.gpu_memory = gpu_memory - self.cpu_memory = cpu_memory - self.tensor_parallel_size = tensor_parallel_size - - config = AutoConfig.from_pretrained(model_name) - self.num_layers = config.num_hidden_layers - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_size = config.hidden_size // self.num_heads - self.ffn_size = config.n_inner if config.n_inner is not None else 4 * self.hidden_size - self.vocab_size = config.vocab_size - self.max_position = config.max_position_embeddings - - def get_param_size(self) -> int: - word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size - position_embedding = self.max_position * self.hidden_size - - ln1 = 2 * self.hidden_size - q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - mha = ln1 + q + k + v + out - - ln2 = 2 * self.hidden_size - ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size - ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - ffn = ln2 + ffn1 + ffn2 - - total = (word_embedding + position_embedding + - self.num_layers * (mha + ffn)) - dtype_size = get_dtype_size(self.dtype) - return dtype_size * total - - def get_max_act_size( - self, - max_num_batched_tokens: int, - ) -> int: - # NOTE: We approxmiately calculate the maximum activation size by - # estimating - # 1) the maximum activation tensor size during inference - # 2) the residual tensor size during inference - # Here, we assume that FlashAttention is used and - # thus the attention maps are never materialized in GPU DRAM. 
- residual = max_num_batched_tokens * self.hidden_size - qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size - ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size - # Double the activation size for input and output. - max_act = 2 * (max(qkv, ffn) + residual) - # Size of output logits. - output_logits = 2 * (max_num_batched_tokens * self.vocab_size) - max_act = max(max_act, output_logits) - dtype_size = get_dtype_size(self.dtype) - return dtype_size * max_act - - -class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer): - - def __init__( - self, - model_name: str, - block_size: int, - dtype: torch.dtype, - gpu_memory: int, - cpu_memory: int, - tensor_parallel_size: int, - ) -> None: - self.model_name = model_name - self.block_size = block_size - self.dtype = dtype - self.gpu_memory = gpu_memory - self.cpu_memory = cpu_memory - self.tensor_parallel_size = tensor_parallel_size - - config = AutoConfig.from_pretrained(model_name) - self.num_layers = config.num_hidden_layers - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_size = config.hidden_size // self.num_heads - self.ffn_size = config.ffn_dim - self.embedding_size = config.word_embed_proj_dim - self.vocab_size = config.vocab_size - self.max_position = config.max_position_embeddings - - def get_param_size(self) -> int: - word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size - if self.embedding_size != self.hidden_size: - # Project in/out. - word_embedding += 2 * self.embedding_size * self.hidden_size - position_embedding = self.max_position * self.hidden_size - - ln1 = 2 * self.hidden_size - q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - mha = ln1 + q + k + v + out - - ln2 = 2 * self.hidden_size - ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size - ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - ffn = ln2 + ffn1 + ffn2 - - total = (word_embedding + position_embedding + - self.num_layers * (mha + ffn)) - dtype_size = get_dtype_size(self.dtype) - return dtype_size * total - - def get_max_act_size( - self, - max_num_batched_tokens: int, - ) -> int: - # NOTE: We approxmiately calculate the maximum activation size by - # estimating - # 1) the maximum activation tensor size during inference - # 2) the residual tensor size during inference - # Here, we assume that we use memory-efficient attention which - # does not materialize the attention maps in GPU DRAM. - residual = max_num_batched_tokens * self.hidden_size - qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size - ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size - # Double the activation size for input and output. - max_act = 2 * (max(qkv, ffn) + residual) - # Size of output logits. 
- output_logits = 2 * (max_num_batched_tokens * self.vocab_size) - max_act = max(max_act, output_logits) - dtype_size = get_dtype_size(self.dtype) - return dtype_size * max_act - - -class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer): - - def __init__( - self, - model_name: str, - block_size: int, - dtype: torch.dtype, - gpu_memory: int, - cpu_memory: int, - tensor_parallel_size: int, - ) -> None: - self.model_name = model_name - self.block_size = block_size - self.dtype = dtype - self.gpu_memory = gpu_memory - self.cpu_memory = cpu_memory - self.tensor_parallel_size = tensor_parallel_size - - config = AutoConfig.from_pretrained(model_name) - self.num_layers = config.num_hidden_layers - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_size = config.hidden_size // self.num_heads - self.ffn_size = config.intermediate_size - self.vocab_size = config.vocab_size - self.max_position = 8192 - - def get_param_size(self) -> int: - # NOTE: LLaMA does not tie the two embeddings. - word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size - lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size - - # NOTE: LLaMA does not have bias terms. - ln1 = self.hidden_size - q = self.hidden_size * self.hidden_size // self.tensor_parallel_size - k = self.hidden_size * self.hidden_size // self.tensor_parallel_size - v = self.hidden_size * self.hidden_size // self.tensor_parallel_size - out = self.hidden_size * self.hidden_size // self.tensor_parallel_size - # Rotary embedding. - # TODO(woosuk): Share the rotary embedding between layers. - rot = self.max_position * self.head_size - mha = ln1 + q + k + v + out + rot - - ln2 = self.hidden_size - gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size - down = self.ffn_size * self.hidden_size // self.tensor_parallel_size - up = self.hidden_size * self.ffn_size // self.tensor_parallel_size - ffn = ln2 + gate + down + up - - total = word_embedding + self.num_layers * (mha + ffn) + lm_head - dtype_size = get_dtype_size(self.dtype) - return dtype_size * total - - def get_max_act_size( - self, - max_num_batched_tokens: int, - ) -> int: - # NOTE: We approxmiately calculate the maximum activation size by - # estimating - # 1) the maximum activation tensor size during inference - # 2) the residual tensor size during inference - # Here, we assume that we use memory-efficient attention which - # does not materialize the attention maps in GPU DRAM. - residual = max_num_batched_tokens * self.hidden_size - qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size - ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size - # Double the activation size for input and output. - max_act = 2 * (max(qkv, ffn) + residual) - # Size of output logits. 
- output_logits = 2 * (max_num_batched_tokens * self.vocab_size) - max_act = max(max_act, output_logits) - dtype_size = get_dtype_size(self.dtype) - return dtype_size * max_act - - -class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer): - - def __init__( - self, - model_name: str, - block_size: int, - dtype: torch.dtype, - gpu_memory: int, - cpu_memory: int, - tensor_parallel_size: int, - ) -> None: - self.model_name = model_name - self.block_size = block_size - self.dtype = dtype - self.gpu_memory = gpu_memory - self.cpu_memory = cpu_memory - self.tensor_parallel_size = tensor_parallel_size - - config = AutoConfig.from_pretrained(model_name) - self.num_layers = config.num_hidden_layers - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_size = config.hidden_size // self.num_heads - self.ffn_size = config.intermediate_size - self.vocab_size = config.vocab_size - self.max_position = 8192 - self.tie_word_embeddings = config.tie_word_embeddings - - def get_param_size(self) -> int: - word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size - if self.tie_word_embeddings: - lm_head = 0 - else: - lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size - - ln1 = 2 * self.hidden_size - q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - # Rotary embedding. - # TODO(woosuk): Share the rotary embedding between layers. - rot = self.max_position * self.head_size - mha = ln1 + q + k + v + out + rot - - ln2 = 2 * self.hidden_size - ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size - ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size - ffn = ln2 + ffn1 + ffn2 - - total = word_embedding + self.num_layers * (mha + ffn) + lm_head - dtype_size = get_dtype_size(self.dtype) - return dtype_size * total - - def get_max_act_size( - self, - max_num_batched_tokens: int, - ) -> int: - # NOTE: We approxmiately calculate the maximum activation size by - # estimating - # 1) the maximum activation tensor size during inference - # 2) the residual tensor size during inference - # Here, we assume that we use memory-efficient attention which - # does not materialize the attention maps in GPU DRAM. - residual = max_num_batched_tokens * self.hidden_size - qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size - ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size - # Double the activation size for input and output. - max_act = 2 * (max(qkv, ffn) + residual) - # Size of output logits. 
- output_logits = 2 * (max_num_batched_tokens * self.vocab_size) - max_act = max(max_act, output_logits) - dtype_size = get_dtype_size(self.dtype) - return dtype_size * max_act diff --git a/cacheflow/model_executor/model_loader.py b/cacheflow/model_executor/model_loader.py index a89fe758..002f8405 100644 --- a/cacheflow/model_executor/model_loader.py +++ b/cacheflow/model_executor/model_loader.py @@ -5,9 +5,6 @@ import torch import torch.nn as nn from transformers import AutoConfig, PretrainedConfig -from cacheflow.model_executor.memory_analyzer import ( - CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer, - LlamaMemoryAnalyzer, OPTMemoryAnalyzer) from cacheflow.model_executor.models import ( GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM) from cacheflow.model_executor.utils import get_torch_dtype @@ -22,14 +19,6 @@ _MODEL_REGISTRY = { "OPTForCausalLM": OPTForCausalLM, } -_MEMORY_ANALYZERS = { - "GPT2LMHeadModel": GPT2MemoryAnalyzer, - "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer, - "LlamaForCausalLM": LlamaMemoryAnalyzer, - "OPTForCausalLM": OPTMemoryAnalyzer, -} - - def _get_model_architecture(config: PretrainedConfig) -> nn.Module: architectures = getattr(config, "architectures", []) for arch in architectures: @@ -41,17 +30,6 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module: ) -def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer: - architectures = getattr(config, "architectures", []) - for arch in architectures: - if arch in _MEMORY_ANALYZERS: - return _MEMORY_ANALYZERS[arch] - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}" - ) - - def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype: # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. 
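The analyzers deleted above derived parameter, activation, and cache sizes analytically from the model config. The patch keeps only the cache-size arithmetic, as the get_cache_block_size helper added to cacheflow/model_executor/utils.py further down, and replaces the rest with runtime profiling. For orientation, a back-of-the-envelope check of that per-block KV-cache size, using hypothetical LLaMA-7B-like shapes (32 heads of size 128, 32 layers, fp16, block_size 16); the formula mirrors the new helper:

# Hypothetical shapes; the arithmetic matches get_cache_block_size below.
block_size, num_heads, head_size, num_layers, dtype_size = 16, 32, 128, 32, 2
key_cache_block = block_size * num_heads * head_size        # 65,536 elements
value_cache_block = key_cache_block
block_bytes = dtype_size * num_layers * (key_cache_block + value_cache_block)
print(block_bytes)  # 8,388,608 bytes: each cache block holds 8 MiB of KV entries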
@@ -100,18 +78,3 @@ def get_model( model = model.cuda() return model.eval(), torch_dtype - -def get_memory_analyzer( - model_name: str, - block_size: int, - dtype: str, - gpu_memory: int, - cpu_memory: int, - tensor_parallel_size: int = 1, -) -> CacheFlowMemoryAnalyzer: - config = AutoConfig.from_pretrained(model_name) - torch_dtype = _get_dtype(config, dtype) - memory_analyzer = _get_memory_analyzer(config) - return memory_analyzer( - model_name, block_size, torch_dtype, gpu_memory, cpu_memory, - tensor_parallel_size) diff --git a/cacheflow/model_executor/models/gpt2.py b/cacheflow/model_executor/models/gpt2.py index 4810f385..16a16d32 100644 --- a/cacheflow/model_executor/models/gpt2.py +++ b/cacheflow/model_executor/models/gpt2.py @@ -58,7 +58,8 @@ class GPT2Attention(nn.Module): self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True, input_is_parallel=True, perform_initialization=False) - self.attn = GPTCacheFlowAttention(scale=self.scale) + self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim, + scale=self.scale) def forward( self, diff --git a/cacheflow/model_executor/models/gpt_neox.py b/cacheflow/model_executor/models/gpt_neox.py index 916ba3f0..10125637 100644 --- a/cacheflow/model_executor/models/gpt_neox.py +++ b/cacheflow/model_executor/models/gpt_neox.py @@ -62,7 +62,8 @@ class GPTNeoXAttention(nn.Module): scaling = self.head_size ** -0.5 rotary_dim = int(self.head_size * config.rotary_pct) assert rotary_dim % 2 == 0 - self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim) + self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_size, + scaling, rotary_dim) def forward( self, diff --git a/cacheflow/model_executor/models/llama.py b/cacheflow/model_executor/models/llama.py index 04699cad..9a55ef06 100644 --- a/cacheflow/model_executor/models/llama.py +++ b/cacheflow/model_executor/models/llama.py @@ -104,7 +104,8 @@ class LlamaAttention(nn.Module): input_is_parallel=True, perform_initialization=False, ) - self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim) + self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_dim, + self.scaling, rotary_dim=self.head_dim) def forward( self, diff --git a/cacheflow/model_executor/models/opt.py b/cacheflow/model_executor/models/opt.py index e51abe84..eeaa77a6 100644 --- a/cacheflow/model_executor/models/opt.py +++ b/cacheflow/model_executor/models/opt.py @@ -74,7 +74,8 @@ class OPTAttention(nn.Module): self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True, perform_initialization=False) - self.attn = GPTCacheFlowAttention(scale=self.scaling) + self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim, + scale=self.scaling) def forward( self, diff --git a/cacheflow/model_executor/utils.py b/cacheflow/model_executor/utils.py index ae7810ca..d4fe96fe 100644 --- a/cacheflow/model_executor/utils.py +++ b/cacheflow/model_executor/utils.py @@ -40,3 +40,15 @@ def set_random_seed(seed: int) -> None: if model_parallel_is_initialized(): model_parallel_cuda_manual_seed(seed) + + +def get_cache_block_size(block_size: int, + num_heads: int, + head_size: int, + num_layers: int, + dtype: str) -> int: + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = get_dtype_size(dtype) + return dtype_size * total diff --git a/cacheflow/worker/controller.py b/cacheflow/worker/controller.py index 018259b8..e3b14037 100644 --- 
a/cacheflow/worker/controller.py +++ b/cacheflow/worker/controller.py @@ -23,23 +23,18 @@ class Controller: pipeline_parallel_size: int, distributed_init_method: str, model_name: str, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, dtype: str, seed: int, cache_dir: Optional[str], use_dummy_weights: bool, use_np_cache: bool, max_num_batched_tokens: int, + max_num_sequences: int, use_ray: bool, ) -> None: self.stage_id = stage_id self.stage_devices = stage_devices self.model_name = model_name - self.block_size = block_size - self.num_gpu_blocks = num_gpu_blocks - self.num_cpu_blocks = num_cpu_blocks self.use_ray = use_ray # Which pipeline stage is this node assigned to? @@ -56,9 +51,6 @@ class Controller: worker_cls = Worker worker = worker_cls( model_name=model_name, - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, dtype=dtype, seed=seed, distributed_init_method=distributed_init_method, @@ -70,9 +62,44 @@ class Controller: use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache, max_num_batched_tokens=max_num_batched_tokens, + max_num_sequences=max_num_sequences, ) self.workers.append(worker) + def get_num_available_blocks(self, block_size: int, cpu_swap_space: int, + gpu_memory_utilization: float) -> List[Tuple[int, int]]: + all_worker_results = [] + for worker in self.workers: + executor = worker.get_num_available_blocks + if self.use_ray: + executor = executor.remote + + result = executor( + block_size, + cpu_swap_space, + gpu_memory_utilization, + ) + all_worker_results.append(result) + if self.use_ray: + all_worker_results = ray.get(all_worker_results) + return all_worker_results + + def init_cache_engine(self, block_size: int, num_gpu_blocks: int, + num_cpu_blocks: int): + all_worker_futures = [] + for worker in self.workers: + executor = worker.init_cache_engine + if self.use_ray: + executor = executor.remote + future = executor( + block_size, + num_gpu_blocks, + num_cpu_blocks, + ) + all_worker_futures.append(future) + if self.use_ray: + ray.get(all_worker_futures) + def set_next( self, next_node: Union['Controller', 'Scheduler'], diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index c98cf225..58f13c6b 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -3,7 +3,8 @@ from typing import Dict, List, Optional, Tuple import torch -from cacheflow.model_executor import get_model, InputMetadata, set_random_seed +from cacheflow.model_executor import (get_model, get_cache_block_size, + InputMetadata, set_random_seed) from cacheflow.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel, initialize_all_reduce_launcher, @@ -12,6 +13,7 @@ from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import (SequenceData, SequenceGroupMetadata, SequenceOutputs) from cacheflow.worker.cache_engine import CacheEngine +from cacheflow.utils import get_gpu_memory class Worker: @@ -25,9 +27,6 @@ class Worker: def __init__( self, model_name: str, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, dtype: str, seed: int, distributed_init_method: str, @@ -37,6 +36,7 @@ class Worker: use_dummy_weights: bool, use_np_cache: bool, max_num_batched_tokens: int, + max_num_sequences: int, tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, ) -> None: @@ -46,8 +46,8 @@ class Worker: tensor_parallel_size, pipeline_parallel_size) self.worker_id = rank - self.block_size = block_size - set_random_seed(seed) + self.seed = seed + 
set_random_seed(self.seed) # Initialize the model. self.model, self.dtype = get_model( @@ -55,8 +55,10 @@ class Worker: use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache) tensor_model_parallel_world_size = ( get_tensor_model_parallel_world_size()) + self.max_num_batched_tokens = max_num_batched_tokens initialize_all_reduce_launcher( - max_num_batched_tokens, self.model.config.hidden_size, self.dtype) + self.max_num_batched_tokens, self.model.config.hidden_size, self.dtype) + self.max_num_sequences = max_num_sequences self.num_layers = self.model.config.num_hidden_layers assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0 self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size @@ -66,12 +68,80 @@ class Worker: # the random state is not affected by the model initialization. set_random_seed(seed) + # Uninitialized cache engine. Will be initialized with + # self.init_cache_engine(). + self.block_size = None + self.cache_engine = None + self.cache_events = None + self.gpu_cache = None + + @torch.inference_mode() + def get_num_available_blocks( + self, block_size: int, cpu_swap_space: int, + gpu_memory_utilization: float) -> Tuple[int, int]: + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, + top_k=self.model.config.vocab_size - 1) + seqs = [] + for group_id in range(self.max_num_sequences): + seq_len = (self.max_num_batched_tokens // self.max_num_sequences + + (group_id < self.max_num_batched_tokens % + self.max_num_sequences)) + seq_data = SequenceData([0] * seq_len) + seq = SequenceGroupMetadata( + group_id=group_id, + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + ) + seqs.append(seq) + + input_tokens, input_positions, input_metadata = self.prepare_inputs(seqs) + + # Execute the model. + self.model( + input_ids=input_tokens, + positions=input_positions, + kv_caches=[(None, None)] * self.num_layers, + input_metadata=input_metadata, + cache_events=None, + ) + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + peak_memory = torch.cuda.max_memory_allocated() + total_gpu_memory = get_gpu_memory() + cache_block_size = get_cache_block_size(block_size, self.num_heads, + self.head_size, self.num_layers, + self.dtype) + num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization + - peak_memory) // cache_block_size) + num_cpu_blocks = int(cpu_swap_space // cache_block_size) + torch.cuda.empty_cache() + # Reset the seed to ensure that the model output is not affected by + # the profiling. 
+ set_random_seed(self.seed) + return num_gpu_blocks, num_cpu_blocks + + def init_cache_engine(self, block_size: int, num_gpu_blocks: int, + num_cpu_blocks: int): + self.block_size = block_size self.cache_engine = CacheEngine( worker_id=self.worker_id, num_layers=self.num_layers, num_heads=self.num_heads, head_size=self.head_size, - block_size=block_size, + block_size=self.block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, dtype=self.dtype, @@ -129,6 +199,12 @@ class Worker: # is always the first token in the sequence. input_positions.extend(range(len(prompt_tokens))) + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.extend([0] * prompt_len) + continue + # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] for i in range(prompt_len):
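Taken together, the new worker path is: run one dummy forward pass shaped like the largest allowed batch (max_num_sequences sequences totalling max_num_batched_tokens tokens, with block tables unset so the slot mapping is all zeros and nothing is written to a KV cache), read the peak allocation from torch.cuda.max_memory_allocated(), and size the cache from whatever memory the utilization target leaves over. A sketch of just the final arithmetic, with made-up profiling numbers (80 GiB of total device memory, a measured peak of 14 GB, the 8 MiB block size from the earlier example, and the default --swap-space and --gpu-memory-utilization values):

# Made-up numbers; the formulas match Worker.get_num_available_blocks above.
GiB = 1 << 30
total_gpu_memory = 80 * GiB            # hypothetical device size
peak_memory = 14 * 10**9               # made-up torch.cuda.max_memory_allocated() reading
cache_block_size = 8 * (1 << 20)       # 8 MiB/block, from the example above
gpu_memory_utilization = 0.95          # --gpu-memory-utilization default
cpu_swap_space = 20 * GiB              # --swap-space default, already converted to bytes

num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory)
                     // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
print(num_gpu_blocks, num_cpu_blocks)  # 8059 2560

The CPU-side count depends only on the swap space, which process_server_arguments now converts from GiB to bytes before it reaches the worker. Both counts then flow back through Controller.get_num_available_blocks (dispatched via Ray when tensor or pipeline parallelism is enabled), the server takes the minimum across workers, and init_cache_engine builds each worker's CacheEngine with the agreed-upon block counts.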