From 2f4887de77fc36ec27f6b2e8d4cd52c9cf02efed Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 16 Feb 2023 01:24:45 +0000
Subject: [PATCH] Fix KVCache shape

---
 cacheflow/worker/cache_engine.py | 75 ++++++++++++++++++--------------
 1 file changed, 43 insertions(+), 32 deletions(-)

diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py
index 2a9a180a..4a2fc350 100644
--- a/cacheflow/worker/cache_engine.py
+++ b/cacheflow/worker/cache_engine.py
@@ -39,42 +39,53 @@ class CacheEngine:
         # Initialize the events for stream synchronization.
         self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
 
-    def allocate_gpu_cache(self) -> List[List[KVCache]]:
-        gpu_cache: List[List[KVCache]] = []
+    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
+        element_size = torch.tensor([], dtype=self.dtype).element_size()
+        x = 16 // element_size
+        return (
+            self.num_heads,
+            self.head_size // x,
+            self.block_size,
+            x,
+        )
+
+    def get_value_block_shape(self) -> Tuple[int, int, int]:
+        return (
+            self.num_heads,
+            self.block_size,
+            self.head_size,
+        )
+
+    def allocate_gpu_cache(self) -> List[KVCache]:
+        gpu_cache: List[KVCache] = []
         for _ in range(self.num_layers):
-            layer_cache: List[KVCache] = []
-            for _ in range(self.num_heads):
-                key_blocks = torch.empty(
-                    (self.num_gpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    device=self.gpu_id,
-                )
-                value_blocks = torch.empty(
-                    (self.num_gpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    device=self.gpu_id,
-                )
-                layer_cache.append((key_blocks, value_blocks))
-            gpu_cache.append(layer_cache)
+            key_blocks = torch.empty(
+                size=(self.num_gpu_blocks, *self.get_key_block_shape()),
+                dtype=self.dtype,
+                device=self.gpu_id,
+            )
+            value_blocks = torch.empty(
+                size=(self.num_gpu_blocks, *self.get_value_block_shape()),
+                dtype=self.dtype,
+                device=self.gpu_id,
+            )
+            gpu_cache.append((key_blocks, value_blocks))
         return gpu_cache
 
-    def allocate_cpu_cache(self) -> List[List[KVCache]]:
-        cpu_cache: List[List[KVCache]] = []
+    def allocate_cpu_cache(self) -> List[KVCache]:
+        cpu_cache: List[KVCache] = []
         for _ in range(self.num_layers):
-            layer_cache: List[KVCache] = []
-            for _ in range(self.num_heads):
-                key_blocks = torch.empty(
-                    (self.num_cpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    pin_memory=True,
-                )
-                value_blocks = torch.empty(
-                    (self.num_cpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    pin_memory=True,
-                )
-                layer_cache.append((key_blocks, value_blocks))
-            cpu_cache.append(layer_cache)
+            key_blocks = torch.empty(
+                size=(self.num_cpu_blocks, *self.get_key_block_shape()),
+                dtype=self.dtype,
+                pin_memory=True,
+            )
+            value_blocks = torch.empty(
+                size=(self.num_cpu_blocks, *self.get_value_block_shape()),
+                dtype=self.dtype,
+                pin_memory=True,
+            )
+            cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache
 
     def copy(self, src_to_dst: Dict[int, int]) -> None:
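
For illustration, here is a minimal standalone sketch (not part of the patch) of the shapes the new helpers produce. The configuration values (num_heads=12, head_size=64, block_size=8, num_gpu_blocks=4) are hypothetical. The key cache splits head_size into head_size // x chunks of x elements, where x is chosen so one chunk spans 16 bytes, presumably to let the attention kernel issue 16-byte vectorized loads; the value cache keeps a flat (num_heads, block_size, head_size) layout. Both layouts hold the same number of elements per block.

import torch

# Hypothetical example configuration, not taken from cacheflow.
dtype = torch.float16
num_heads = 12
head_size = 64
block_size = 8
num_gpu_blocks = 4

# Mirrors get_key_block_shape(): pack the head dimension into chunks of
# x elements so that each chunk occupies exactly 16 bytes.
element_size = torch.tensor([], dtype=dtype).element_size()  # 2 for fp16
x = 16 // element_size                                       # 8 for fp16
key_block_shape = (num_heads, head_size // x, block_size, x)

# Mirrors get_value_block_shape(): values keep the plain layout.
value_block_shape = (num_heads, block_size, head_size)

key_blocks = torch.empty(size=(num_gpu_blocks, *key_block_shape), dtype=dtype)
value_blocks = torch.empty(size=(num_gpu_blocks, *value_block_shape), dtype=dtype)

# Both layouts store the same number of elements per block.
assert key_blocks.numel() == value_blocks.numel() \
    == num_gpu_blocks * num_heads * block_size * head_size
print(key_blocks.shape)    # torch.Size([4, 12, 8, 8, 8])
print(value_blocks.shape)  # torch.Size([4, 12, 8, 64])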