Fix KVCache shape

Woosuk Kwon 2023-02-16 01:24:45 +00:00
parent 3363c27d19
commit 2f4887de77


@@ -39,42 +39,53 @@ class CacheEngine:
         # Initialize the events for stream synchronization.
         self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
 
-    def allocate_gpu_cache(self) -> List[List[KVCache]]:
-        gpu_cache: List[List[KVCache]] = []
+    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
+        element_size = torch.tensor([], dtype=self.dtype).element_size()
+        x = 16 // element_size
+        return (
+            self.num_heads,
+            self.head_size // x,
+            self.block_size,
+            x,
+        )
+
+    def get_value_block_shape(self) -> Tuple[int, int, int]:
+        return (
+            self.num_heads,
+            self.block_size,
+            self.head_size,
+        )
+
+    def allocate_gpu_cache(self) -> List[KVCache]:
+        gpu_cache: List[KVCache] = []
         for _ in range(self.num_layers):
-            layer_cache: List[KVCache] = []
-            for _ in range(self.num_heads):
-                key_blocks = torch.empty(
-                    (self.num_gpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    device=self.gpu_id,
-                )
-                value_blocks = torch.empty(
-                    (self.num_gpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    device=self.gpu_id,
-                )
-                layer_cache.append((key_blocks, value_blocks))
-            gpu_cache.append(layer_cache)
+            key_blocks = torch.empty(
+                size=(self.num_gpu_blocks, *self.get_key_block_shape()),
+                dtype=self.dtype,
+                device=self.gpu_id,
+            )
+            value_blocks = torch.empty(
+                size=(self.num_gpu_blocks, *self.get_value_block_shape()),
+                dtype=self.dtype,
+                device=self.gpu_id,
+            )
+            gpu_cache.append((key_blocks, value_blocks))
         return gpu_cache
 
-    def allocate_cpu_cache(self) -> List[List[KVCache]]:
-        cpu_cache: List[List[KVCache]] = []
+    def allocate_cpu_cache(self) -> List[KVCache]:
+        cpu_cache: List[KVCache] = []
         for _ in range(self.num_layers):
-            layer_cache: List[KVCache] = []
-            for _ in range(self.num_heads):
-                key_blocks = torch.empty(
-                    (self.num_cpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    pin_memory=True,
-                )
-                value_blocks = torch.empty(
-                    (self.num_cpu_blocks, self.block_size * self.head_size),
-                    dtype=self.dtype,
-                    pin_memory=True,
-                )
-                layer_cache.append((key_blocks, value_blocks))
-            cpu_cache.append(layer_cache)
+            key_blocks = torch.empty(
+                size=(self.num_cpu_blocks, *self.get_key_block_shape()),
+                dtype=self.dtype,
+                pin_memory=True,
+            )
+            value_blocks = torch.empty(
+                size=(self.num_cpu_blocks, *self.get_value_block_shape()),
+                dtype=self.dtype,
+                pin_memory=True,
+            )
+            cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache
 
     def copy(self, src_to_dst: Dict[int, int]) -> None:
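
For context, the arithmetic behind the new block shapes can be checked in isolation. The key cache splits the head dimension into chunks of x elements, where x is however many elements fit in 16 bytes, presumably so the attention kernel can access keys in 16-byte units; both layouts hold the same number of elements per block. The sketch below is illustrative only: the model dimensions (num_heads=12, head_size=64, block_size=16, num_gpu_blocks=4) are made-up values, not ones fixed by this commit.

    import torch

    # Illustrative (hypothetical) dimensions; the commit does not fix these.
    dtype = torch.float16
    num_heads, head_size, block_size, num_gpu_blocks = 12, 64, 16, 4

    # Same computation as CacheEngine.get_key_block_shape:
    element_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for fp16
    x = 16 // element_size                                       # 8 fp16 elements per 16 bytes

    key_block_shape = (num_heads, head_size // x, block_size, x)  # (12, 8, 16, 8)
    value_block_shape = (num_heads, block_size, head_size)        # (12, 16, 64)

    key_blocks = torch.empty(size=(num_gpu_blocks, *key_block_shape), dtype=dtype)
    value_blocks = torch.empty(size=(num_gpu_blocks, *value_block_shape), dtype=dtype)

    # Both layouts store the same number of elements per block.
    assert key_blocks.numel() == value_blocks.numel()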