Fix cache engine
parent 5a309bb588
commit bb59a3e730
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 
@@ -14,34 +14,30 @@ class CacheEngine:
         num_layers: int,
         num_heads: int,
         head_size: int,
+        block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
-        block_size: int,
-        dtype: torch.dtype = torch.float16,
+        dtype: torch.dtype,
     ) -> None:
         self.worker_id = worker_id
         self.gpu_id = gpu_id
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.head_size = head_size
+        self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
-        self.block_size = block_size
         self.dtype = dtype
 
         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
 
-        # Initialize the streams.
-        self.copy_stream = torch.cuda.Stream(device=gpu_id)
-        self.swap_stream = torch.cuda.Stream(device=gpu_id)
-        assert self.copy_stream != self.swap_stream
-        current_stream = torch.cuda.current_stream(device=gpu_id)
-        assert self.copy_stream != current_stream
-        assert self.swap_stream != current_stream
-
-        # Initialize the events for synchronization.
-
+        # Initialize the stream for caching operations.
+        self.cache_stream = torch.cuda.Stream(device=gpu_id)
+        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
+        # Initialize the events for stream synchronization.
+        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
 
     def allocate_gpu_cache(self) -> List[List[KVCache]]:
         gpu_cache: List[List[KVCache]] = []
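Note on the stream change above: the two dedicated streams (copy_stream, swap_stream) collapse into a single cache_stream, and the previously empty "events" section now materializes as one torch.cuda.Event per layer. Below is a minimal sketch of how such a stream/event pair is typically used together; the helper names, the loop, and the cache layout are illustrative assumptions, not code from this commit:

import torch

def swap_in_all_layers(cache_stream, events, gpu_cache, cpu_cache):
    # Issue all copies on the side stream so they can overlap with compute.
    with torch.cuda.stream(cache_stream):
        for layer, (gpu_kv, cpu_kv) in enumerate(zip(gpu_cache, cpu_cache)):
            for gpu_tensor, cpu_tensor in zip(gpu_kv, cpu_kv):
                # Async host-to-device copy; the CPU tensors must live in
                # pinned memory for the copy to actually be asynchronous.
                gpu_tensor.copy_(cpu_tensor, non_blocking=True)
            # Record completion of this layer's copies on the cache stream.
            events[layer].record(cache_stream)

def wait_for_layer(events, layer):
    # Make the current (compute) stream wait only for this layer's copies,
    # rather than synchronizing on the entire swap.
    events[layer].wait(torch.cuda.current_stream())

The assert in the constructor guards this pattern's key invariant: if cache_stream were the current compute stream, the copies could not overlap with model execution.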
@@ -81,29 +77,14 @@ class CacheEngine:
         cpu_cache.append(layer_cache)
         return cpu_cache
 
-    def copy(
-        self,
-        src_block_numbers: List[int],
-        dst_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the COPY op.
-            pass
+    def copy(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
+            pass
 
-    def swap_out(
-        self,
-        gpu_block_numbers: List[int],
-        cpu_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the SWAP_OUT op on the swap stream.
-            pass
+    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
+            pass
 
-    def swap_in(
-        self,
-        gpu_block_numbers: List[int],
-        cpu_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the SWAP_IN op on the swap stream.
-            pass
+    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
+            pass
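Note on the method rewrites above: copy, swap_in, and swap_out drop their parallel source/destination block-number lists in favor of a single src_to_dst: Dict[int, int] mapping (hence the new Dict import), which rules out mismatched-length bugs by construction. The bodies remain stubs in this commit; the following is only a hypothetical completion of swap_in, under the assumption that block number is the leading dimension of each cache tensor:

from typing import Dict
import torch

def swap_in(self, src_to_dst: Dict[int, int]) -> None:
    # src_to_dst maps a CPU block number to the GPU block number that
    # should receive its contents.
    with torch.cuda.stream(self.cache_stream):
        for layer in range(self.num_layers):
            gpu_key, gpu_value = self.gpu_cache[layer]
            cpu_key, cpu_value = self.cpu_cache[layer]
            for src, dst in src_to_dst.items():
                # Copy one block of keys and one block of values per layer.
                gpu_key[dst].copy_(cpu_key[src], non_blocking=True)
                gpu_value[dst].copy_(cpu_value[src], non_blocking=True)
            # Let compute wait per layer instead of on the whole swap.
            self.events[layer].record(self.cache_stream)

For example, swap_in({0: 5, 1: 7}) would copy CPU blocks 0 and 1 into GPU blocks 5 and 7, respectively, for every layer.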