[Hardware][TPU] Optimize KV cache swapping (#5878)

commit f136da15e1
parent c3dde367f1
vllm/attention/backends/pallas.py

@@ -28,21 +28,13 @@ class PallasAttentionBackend(AttentionBackend):
     ) -> Tuple[int, ...]:
         return (num_kv_heads, num_blocks, block_size, head_size)
 
-    @torch.compile(backend="openxla")
     @staticmethod
     def swap_blocks(
-        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        src_to_dst: Tuple[torch.Tensor, torch.Tensor],
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
     ) -> None:
-        src_k_cache, src_v_cache = src_kv_cache
-        dst_k_cache, dst_v_cache = dst_kv_cache
-        src_indices, dst_indices = src_to_dst
-        device = dst_k_cache.device
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-        dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
-        dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
+        raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
     @torch.compile(backend="openxla")
     @staticmethod
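
For context: get_kv_cache_shape above lays each K and V cache out as (num_kv_heads, num_blocks, block_size, head_size), so moving blocks between caches is a gather/scatter along dimension 1. A minimal CPU-only sketch of that indexing pattern; the variable names mirror the removed code, but the sizes are made up:

    import torch

    # Illustrative sizes only; the real values come from the model config.
    num_kv_heads, num_blocks, block_size, head_size = 2, 8, 16, 64

    src_k_cache = torch.randn(num_kv_heads, num_blocks, block_size, head_size)
    dst_k_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)

    # Copy source blocks 1 and 3 into destination slots 0 and 2.
    src_indices = torch.tensor([1, 3], dtype=torch.int64)
    dst_indices = torch.tensor([0, 2], dtype=torch.int64)

    # Gather along the block dimension (dim 1), scatter into the target slots.
    dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices]
    assert torch.equal(dst_k_cache[:, 0], src_k_cache[:, 1])
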
vllm/worker/tpu_worker.py

@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
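
The # noqa: F401 import is load-bearing: importing torch_xla.experimental.dynamo_set_buffer_donor registers the torch.ops.xla.dynamo_set_buffer_donor_ op used later in this file. A rough sketch of the buffer-donation pattern in isolation, assuming an XLA/TPU device is available (the function and variable names here are illustrative, not part of the diff):

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401

    @torch.compile(backend="openxla")
    def overwrite_rows(cache: torch.Tensor, rows: torch.Tensor,
                       values: torch.Tensor) -> None:
        # Donating the buffer lets XLA update the cache in place instead of
        # materializing a second copy of the (large) tensor.
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
        cache[rows] = values

    device = xm.xla_device()
    cache = torch.zeros(8, 4, device=device)
    overwrite_rows(cache, torch.tensor([1, 3], device=device),
                   torch.ones(2, 4, device=device))
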
@@ -152,8 +153,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
-        self.cpu_cache = []
-        self.tpu_cache = []
+        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
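
With the annotations above, each layer's cache is a (k_cache, v_cache) pair rather than a single stacked tensor. A hedged sketch of what the allocation amounts to; the layer count, sizes, and dtype are placeholders, and the device is CPU here only so the sketch runs anywhere:

    from typing import List, Tuple

    import torch

    # Placeholder configuration; the real values come from the model/cache config.
    num_layers, num_kv_heads, block_size, head_size = 4, 2, 16, 64
    num_tpu_blocks, num_cpu_blocks = 128, 256
    dtype = torch.bfloat16
    device = torch.device("cpu")  # stands in for the TPU device

    tpu_cache_shape = (num_kv_heads, num_tpu_blocks, block_size, head_size)
    cpu_cache_shape = (num_kv_heads, num_cpu_blocks, block_size, head_size)

    tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
    cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
    for _ in range(num_layers):
        tpu_cache.append(
            (torch.zeros(tpu_cache_shape, dtype=dtype, device=device),
             torch.zeros(tpu_cache_shape, dtype=dtype, device=device)))
        cpu_cache.append(
            (torch.zeros(cpu_cache_shape, dtype=dtype),
             torch.zeros(cpu_cache_shape, dtype=dtype)))
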
@@ -227,18 +228,25 @@ class TPUWorker(LoraNotSupportedWorkerBase):
 
         if blocks_to_swap_in:
             # Swap from CPU to TPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
-                                          self.device)
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_in, "cpu", self.device)
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                k = cpu_k_cache[:, src_indices].to(self.device)
+                v = cpu_v_cache[:, src_indices].to(self.device)
+                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
+
         if blocks_to_swap_out:
             # Swap from TPU to CPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
-                                          "cpu")
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_out, self.device, "cpu")
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
+
         if blocks_to_copy:
             src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                           self.device)
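
The rewritten loops rely on _make_src_to_dst, which this commit does not touch; only its tail is visible in the next hunk. Judging from that tail and the call sites, it roughly turns the list of (src_block, dst_block) pairs into two int64 index tensors, one per device. An approximation for reference only; the exact signature and body may differ:

    from typing import List, Tuple, Union

    import torch

    def _make_src_to_dst(
        mapping: List[Tuple[int, int]],
        src_device: Union[torch.device, str],
        dst_device: Union[torch.device, str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # One index tensor per side, placed on the device whose cache it indexes.
        src_indices = torch.tensor([src for src, _ in mapping],
                                   device=src_device,
                                   dtype=torch.int64)
        dst_indices = torch.tensor([dst for _, dst in mapping],
                                   device=dst_device,
                                   dtype=torch.int64)
        return src_indices, dst_indices
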
@@ -267,3 +275,17 @@ def _make_src_to_dst(
                                device=dst_device,
                                dtype=torch.int64)
     return src_indices, dst_indices
+
+
+@torch.compile(backend="openxla")
+def _insert_kv(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    indices: torch.Tensor,
+    tpu_k_cache: torch.Tensor,
+    tpu_v_cache: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
+    tpu_k_cache[:, indices] = k
+    tpu_v_cache[:, indices] = v
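
Taken together with the worker changes, a swap-in for one layer now gathers on the host, copies once to the device, and scatters through the compiled _insert_kv with the TPU-side caches donated. A usage sketch, assuming a TPU is available and the _insert_kv defined above; everything else (shapes, names) is illustrative:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    num_kv_heads, num_blocks, block_size, head_size = 2, 8, 16, 64

    tpu_k_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size,
                              device=device)
    tpu_v_cache = torch.zeros_like(tpu_k_cache)

    # Two blocks gathered from the CPU cache and already copied to the TPU.
    k = torch.randn(num_kv_heads, 2, block_size, head_size, device=device)
    v = torch.randn_like(k)
    dst_indices = torch.tensor([3, 6], dtype=torch.int64, device=device)

    _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)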