Fix KVCache shape
This commit is contained in:
parent
3363c27d19
commit
2f4887de77
@ -39,42 +39,53 @@ class CacheEngine:
|
||||
# Initialize the events for stream synchronization.
|
||||
self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
|
||||
|
||||
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
    """Return the per-block shape of the key cache, excluding the
    leading num_blocks dimension.

    The head_size dimension is split into (head_size // x, x) where
    x = 16 // element_size, so each innermost vector spans exactly
    16 bytes regardless of dtype.

    Returns:
        (num_heads, head_size // x, block_size, x).
    """
    # Size in bytes of one element of the cache dtype (e.g. 2 for fp16).
    element_size = torch.tensor([], dtype=self.dtype).element_size()
    x = 16 // element_size
    # NOTE(review): assumes head_size is divisible by x — not checked here.
    return (
        self.num_heads,
        self.head_size // x,
        self.block_size,
        x,
    )
|
||||
|
||||
def get_value_block_shape(self) -> Tuple[int, int, int]:
    """Return the per-block shape of the value cache, excluding the
    leading num_blocks dimension.

    Returns:
        (num_heads, block_size, head_size).
    """
    return (
        self.num_heads,
        self.block_size,
        self.head_size,
    )
|
||||
|
||||
def allocate_gpu_cache(self) -> List["KVCache"]:
    """Allocate the GPU key/value cache.

    Returns:
        A list of num_layers (key_blocks, value_blocks) pairs, where
        key_blocks has shape (num_gpu_blocks, *get_key_block_shape())
        and value_blocks has shape
        (num_gpu_blocks, *get_value_block_shape()), both uninitialized
        and placed on self.gpu_id.
    """
    gpu_cache: List["KVCache"] = []
    # The shapes are identical for every layer; compute them once.
    key_block_shape = self.get_key_block_shape()
    value_block_shape = self.get_value_block_shape()
    for _ in range(self.num_layers):
        key_blocks = torch.empty(
            size=(self.num_gpu_blocks, *key_block_shape),
            dtype=self.dtype,
            device=self.gpu_id,
        )
        value_blocks = torch.empty(
            size=(self.num_gpu_blocks, *value_block_shape),
            dtype=self.dtype,
            device=self.gpu_id,
        )
        # KVCache is presumably a (key, value) tensor pair — defined elsewhere.
        gpu_cache.append((key_blocks, value_blocks))
    return gpu_cache
|
||||
|
||||
def allocate_cpu_cache(self) -> List["KVCache"]:
    """Allocate the CPU (swap) key/value cache in pinned host memory.

    Returns:
        A list of num_layers (key_blocks, value_blocks) pairs, where
        key_blocks has shape (num_cpu_blocks, *get_key_block_shape())
        and value_blocks has shape
        (num_cpu_blocks, *get_value_block_shape()), both uninitialized.
    """
    cpu_cache: List["KVCache"] = []
    # The shapes are identical for every layer; compute them once.
    key_block_shape = self.get_key_block_shape()
    value_block_shape = self.get_value_block_shape()
    for _ in range(self.num_layers):
        key_blocks = torch.empty(
            size=(self.num_cpu_blocks, *key_block_shape),
            dtype=self.dtype,
            # Pinned memory enables async host<->device copies for swapping.
            pin_memory=True,
        )
        value_blocks = torch.empty(
            size=(self.num_cpu_blocks, *value_block_shape),
            dtype=self.dtype,
            pin_memory=True,
        )
        cpu_cache.append((key_blocks, value_blocks))
    return cpu_cache
|
||||
|
||||
def copy(self, src_to_dst: Dict[int, int]) -> None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user