vllm/vllm/worker/cache_engine.py

"""CacheEngine class for managing the KV cache."""
from typing import List

import torch

from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available

logger = init_logger(__name__)


class CacheEngine:
"""Manages the KV cache.
This class is responsible for initializing and managing the GPU and CPU KV
caches. It also provides methods for performing KV cache operations, such
as swapping and copying.
"""
2023-02-09 19:28:02 +08:00
def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
self.cache_config = cache_config
self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        if cache_config.cache_dtype == "auto":
            self.dtype = model_config.dtype
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # Get attention backend.
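        # The backend determines both the attention kernel and the physical
        # KV-cache layout (e.g. a FlashAttention- or XFormers-based path), so
        # it is resolved once here and reused for allocation and swapping.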
self.attn_backend = get_attn_backend(
model_config.get_num_attention_heads(parallel_config),
self.head_size,
self.num_kv_heads,
model_config.get_sliding_window(),
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
        )

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

def _allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
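        # Pinned (page-locked) host memory makes CPU<->GPU swaps faster and
        # allows asynchronous copies; fall back to pageable memory otherwise.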
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
            # The null block in CpuGpuBlockAllocator requires at least that
            # one block to be zeroed out; we zero out everything for
            # simplicity.
kv_cache.append(
torch.zeros(kv_cache_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device))
        return kv_cache

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Copies the mapped blocks from the CPU cache into the GPU cache."""
        for i in range(self.num_layers):
            self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                          src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Copies the mapped blocks from the GPU cache out to the CPU cache."""
        for i in range(self.num_layers):
            self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                          src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copies blocks within the GPU cache (used for copy-on-write)."""
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Returns the size in bytes of one cache block (i.e. block_size
        tokens' worth of keys and values) summed across all layers."""
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
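        # `total` is the element count of one block across all layers; it is
        # converted to bytes below using the cache dtype's element size.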
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
        dtype_size = _get_dtype_size(dtype)
        return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
    """Returns the size in bytes of one element of the given torch dtype."""
    return torch.tensor([], dtype=dtype).element_size()
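

# Illustrative usage sketch (comments only, not part of the module's API):
# the worker constructs one CacheEngine per model and drives it with the
# block-number mappings produced by the scheduler, roughly as follows:
#
#     cache_engine = CacheEngine(cache_config, model_config, parallel_config)
#     cache_engine.swap_in(src_to_dst)    # mapped blocks: CPU -> GPU
#     cache_engine.swap_out(src_to_dst)   # mapped blocks: GPU -> CPU
#     cache_engine.copy(src_to_dsts)      # duplicate blocks within GPU cache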