Fix cache engine
parent 5a309bb588
commit bb59a3e730
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import torch

@@ -14,34 +14,30 @@ class CacheEngine:
         num_layers: int,
         num_heads: int,
         head_size: int,
-        block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
+        block_size: int,
-        dtype: torch.dtype = torch.float16,
+        dtype: torch.dtype,
     ) -> None:
         self.worker_id = worker_id
         self.gpu_id = gpu_id
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.head_size = head_size
-        self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
+        self.block_size = block_size
         self.dtype = dtype

         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()

-        # Initialize the streams.
-        self.copy_stream = torch.cuda.Stream(device=gpu_id)
-        self.swap_stream = torch.cuda.Stream(device=gpu_id)
-        assert self.copy_stream != self.swap_stream
-        current_stream = torch.cuda.current_stream(device=gpu_id)
-        assert self.copy_stream != current_stream
-        assert self.swap_stream != current_stream
-        # Initialize the events for synchronization.
+        # Initialize the stream for caching operations.
+        self.cache_stream = torch.cuda.Stream(device=gpu_id)
+        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
+        # Initialize the events for stream synchronization.
         self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

     def allocate_gpu_cache(self) -> List[List[KVCache]]:
         gpu_cache: List[List[KVCache]] = []
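Note: the hunk above replaces the two dedicated copy/swap streams with a single cache_stream plus one CUDA event per layer, so each layer's cache operation can be synchronized with the compute stream independently. A minimal sketch of that stream/event pattern (not part of this commit; the tensor shapes and the src/dst names are invented for illustration):

import torch

num_layers = 2
device = torch.device("cuda")
cache_stream = torch.cuda.Stream(device=device)
events = [torch.cuda.Event() for _ in range(num_layers)]

# Stand-in per-layer cache blocks.
src = [torch.randn(16, 64, device=device) for _ in range(num_layers)]
dst = [torch.empty_like(t) for t in src]

# Issue all copies on the side stream, recording one event per layer.
with torch.cuda.stream(cache_stream):
    for layer in range(num_layers):
        dst[layer].copy_(src[layer], non_blocking=True)
        events[layer].record()  # records on cache_stream

# The compute stream waits only on the layers it is about to use.
for layer in range(num_layers):
    torch.cuda.current_stream(device=device).wait_event(events[layer])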
@@ -81,29 +77,14 @@ class CacheEngine:
             cpu_cache.append(layer_cache)
         return cpu_cache

-    def copy(
-        self,
-        src_block_numbers: List[int],
-        dst_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the COPY op.
+    def copy(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass

-    def swap_out(
-        self,
-        gpu_block_numbers: List[int],
-        cpu_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the SWAP_OUT op on the swap stream.
+    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass

-    def swap_in(
-        self,
-        gpu_block_numbers: List[int],
-        cpu_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the SWAP_IN op on the swap stream.
+    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass
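Note: copy, swap_in, and swap_out now each take a single src_to_dst mapping instead of two parallel block-number lists, which rules out length mismatches at the call site. A hypothetical call site (the engine name and block numbers are invented for illustration):

engine.copy({3: 7})              # duplicate GPU block 3 into GPU block 7
engine.swap_out({0: 12, 1: 13})  # GPU blocks 0, 1 -> CPU blocks 12, 13
engine.swap_in({12: 0, 13: 1})   # CPU blocks 12, 13 -> GPU blocks 0, 1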