From bb59a3e7302ad6892e097eee4040e3f516e9f4ea Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 13 Feb 2023 09:35:48 +0000
Subject: [PATCH] Fix cache engine

---
 cacheflow/worker/cache_engine.py | 49 ++++++++++----------------
 1 file changed, 15 insertions(+), 34 deletions(-)

diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py
index 271a6782..2a9a180a 100644
--- a/cacheflow/worker/cache_engine.py
+++ b/cacheflow/worker/cache_engine.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 
@@ -14,34 +14,30 @@ class CacheEngine:
         num_layers: int,
         num_heads: int,
         head_size: int,
+        block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
-        block_size: int,
-        dtype: torch.dtype = torch.float16,
+        dtype: torch.dtype,
     ) -> None:
         self.worker_id = worker_id
         self.gpu_id = gpu_id
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.head_size = head_size
+        self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
-        self.block_size = block_size
         self.dtype = dtype
 
         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
 
-        # Initialize the streams.
-        self.copy_stream = torch.cuda.Stream(device=gpu_id)
-        self.swap_stream = torch.cuda.Stream(device=gpu_id)
-        assert self.copy_stream != self.swap_stream
-        current_stream = torch.cuda.current_stream(device=gpu_id)
-        assert self.copy_stream != current_stream
-        assert self.swap_stream != current_stream
-
-        # Initialize the events for synchronization.
+        # Initialize the stream for caching operations.
+        self.cache_stream = torch.cuda.Stream(device=gpu_id)
+        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
+        # Initialize the events for stream synchronization.
+        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
 
     def allocate_gpu_cache(self) -> List[List[KVCache]]:
         gpu_cache: List[List[KVCache]] = []
@@ -81,29 +77,14 @@ class CacheEngine:
         cpu_cache.append(layer_cache)
         return cpu_cache
 
-    def copy(
-        self,
-        src_block_numbers: List[int],
-        dst_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the COPY op.
+    def copy(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass
 
-    def swap_out(
-        self,
-        gpu_block_numbers: List[int],
-        cpu_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the SWAP_OUT op on the swap stream.
+    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass
 
-    def swap_in(
-        self,
-        gpu_block_numbers: List[int],
-        cpu_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the SWAP_IN op on the swap stream.
+    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass
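
Note (not part of the patch): the new copy/swap_in/swap_out signatures are left as stubs above. Below is a minimal sketch of how swap_out might eventually use cache_stream and the per-layer events. It assumes that gpu_cache[i] and cpu_cache[i] each reduce to one (key, value) tensor pair per layer, indexed by block number along dim 0 (a simplification of the List[List[KVCache]] layout in the file), and it uses plain tensor copies rather than the CUDA SWAP op the removed TODO referred to; it would slot into the CacheEngine class, where torch and Dict are already imported.

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        # Issue all block copies on the dedicated cache stream so they can
        # overlap with compute running on the default stream.
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                gpu_key, gpu_value = self.gpu_cache[i]
                cpu_key, cpu_value = self.cpu_cache[i]
                for src, dst in src_to_dst.items():
                    # Copy one KV block from the GPU cache to the CPU cache.
                    cpu_key[dst].copy_(gpu_key[src], non_blocking=True)
                    cpu_value[dst].copy_(gpu_value[src], non_blocking=True)
                # Record this layer's event so other streams can wait on it.
                self.events[i].record(stream=self.cache_stream)

swap_in would mirror this with the copy direction reversed, and copy would perform GPU-to-GPU block copies within self.gpu_cache using the same stream-and-events pattern.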