Fix KVCache shape

Woosuk Kwon 2023-02-16 01:24:45 +00:00
parent 3363c27d19
commit 2f4887de77


@@ -39,42 +39,53 @@ class CacheEngine:
         # Initialize the events for stream synchronization.
         self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
 
-    def allocate_gpu_cache(self) -> List[List[KVCache]]:
-        gpu_cache: List[List[KVCache]] = []
+    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
+        element_size = torch.tensor([], dtype=self.dtype).element_size()
+        x = 16 // element_size
+        return (
+            self.num_heads,
+            self.head_size // x,
+            self.block_size,
+            x,
+        )
+
+    def get_value_block_shape(self) -> Tuple[int, int, int]:
+        return (
+            self.num_heads,
+            self.block_size,
+            self.head_size,
+        )
+
+    def allocate_gpu_cache(self) -> List[KVCache]:
+        gpu_cache: List[KVCache] = []
         for _ in range(self.num_layers):
-            layer_cache: List[KVCache] = []
-            for _ in range(self.num_heads):
-                key_blocks = torch.empty(
-                    (self.num_gpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    device=self.gpu_id,
-                )
-                value_blocks = torch.empty(
-                    (self.num_gpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    device=self.gpu_id,
-                )
-                layer_cache.append((key_blocks, value_blocks))
-            gpu_cache.append(layer_cache)
+            key_blocks = torch.empty(
+                size=(self.num_gpu_blocks, *self.get_key_block_shape()),
+                dtype=self.dtype,
+                device=self.gpu_id,
+            )
+            value_blocks = torch.empty(
+                size=(self.num_gpu_blocks, *self.get_value_block_shape()),
+                dtype=self.dtype,
+                device=self.gpu_id,
+            )
+            gpu_cache.append((key_blocks, value_blocks))
         return gpu_cache
 
-    def allocate_cpu_cache(self) -> List[List[KVCache]]:
-        cpu_cache: List[List[KVCache]] = []
+    def allocate_cpu_cache(self) -> List[KVCache]:
+        cpu_cache: List[KVCache] = []
         for _ in range(self.num_layers):
-            layer_cache: List[KVCache] = []
-            for _ in range(self.num_heads):
-                key_blocks = torch.empty(
-                    (self.num_cpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    pin_memory=True,
-                )
-                value_blocks = torch.empty(
-                    (self.num_cpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    pin_memory=True,
-                )
-                layer_cache.append((key_blocks, value_blocks))
-            cpu_cache.append(layer_cache)
+            key_blocks = torch.empty(
+                size=(self.num_cpu_blocks, *self.get_key_block_shape()),
+                dtype=self.dtype,
+                pin_memory=True,
+            )
+            value_blocks = torch.empty(
+                size=(self.num_cpu_blocks, *self.get_value_block_shape()),
+                dtype=self.dtype,
+                pin_memory=True,
+            )
+            cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache
 
     def copy(self, src_to_dst: Dict[int, int]) -> None:
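
For context, a minimal standalone sketch (not part of this commit) of the block shapes the new helpers produce. The num_heads, head_size, and block_size values below are made up for illustration; only the shape logic mirrors the diff. The x factor groups the key cache's innermost dimension into 16-byte chunks, presumably so the attention kernel can issue aligned, vectorized loads.

import torch

# Illustrative values only; a real CacheEngine takes these from the model config.
dtype = torch.float16
num_heads, head_size, block_size = 12, 64, 16

# x = number of elements per 16-byte vector (8 for float16).
element_size = torch.tensor([], dtype=dtype).element_size()
x = 16 // element_size

# Same shapes as get_key_block_shape() / get_value_block_shape() above.
key_block_shape = (num_heads, head_size // x, block_size, x)
value_block_shape = (num_heads, block_size, head_size)

print(key_block_shape)    # (12, 8, 16, 8)
print(value_block_shape)  # (12, 16, 64)

Both shapes hold the same number of elements per block (num_heads * head_size * block_size); only the key layout is reordered. The other effect of the change is that each layer's cache is now a single (key, value) tensor pair with num_heads folded into the tensor shape, rather than one pair per head.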