[Hardware][TPU] Optimize KV cache swapping (#5878)
parent c3dde367f1
commit f136da15e1
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -28,21 +28,13 @@ class PallasAttentionBackend(AttentionBackend):
     ) -> Tuple[int, ...]:
         return (num_kv_heads, num_blocks, block_size, head_size)
 
-    @torch.compile(backend="openxla")
     @staticmethod
     def swap_blocks(
-        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        src_to_dst: Tuple[torch.Tensor, torch.Tensor],
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
     ) -> None:
-        src_k_cache, src_v_cache = src_kv_cache
-        dst_k_cache, dst_v_cache = dst_kv_cache
-        src_indices, dst_indices = src_to_dst
-        device = dst_k_cache.device
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-        dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
-        dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
+        raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
     @torch.compile(backend="openxla")
     @staticmethod
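
The removed body above (and the worker-side path that replaces it) both lean on the same torch_xla pattern: mark the destination cache as a donated buffer inside a torch.compile(backend="openxla") function so XLA may reuse its storage for the in-place scatter instead of keeping a second copy alive. A minimal sketch of that pattern, assuming a working torch_xla install with a reachable XLA device; the function name, shapes, and indices below are illustrative and not part of this commit:

# Illustrative sketch only; assumes torch_xla is installed and an XLA device
# is reachable.  _scatter_blocks and the shapes below are hypothetical.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401

@torch.compile(backend="openxla")
def _scatter_blocks(src: torch.Tensor, dst: torch.Tensor,
                    src_indices: torch.Tensor,
                    dst_indices: torch.Tensor) -> None:
    # Donation tells XLA it may overwrite dst's buffer in place.
    torch.ops.xla.dynamo_set_buffer_donor_(dst, True)
    dst[:, dst_indices] = src[:, src_indices]

device = xm.xla_device()
# Cache layout from get_kv_cache_shape: (num_kv_heads, num_blocks, block_size, head_size)
src = torch.randn(8, 16, 16, 128, device=device)
dst = torch.zeros(8, 16, 16, 128, device=device)
_scatter_blocks(src, dst,
                torch.tensor([0, 1], device=device),
                torch.tensor([2, 3], device=device))
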
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
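
The only change here is a side-effect import: pulling in torch_xla.experimental.dynamo_set_buffer_donor registers the custom op that the compiled helpers below call as torch.ops.xla.dynamo_set_buffer_donor_, which is why it carries noqa: F401 despite never being referenced by name. A minimal sketch, assuming torch_xla is installed:

import torch
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401

# The import's side effect is op registration; afterwards the op can be
# resolved by name and used inside torch.compile'd functions.
donor_op = torch.ops.xla.dynamo_set_buffer_donor_
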
@@ -152,8 +153,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
-        self.cpu_cache = []
-        self.tpu_cache = []
+        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
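
The new annotations make explicit that both caches are lists holding one (key, value) tensor pair per layer, with each tensor shaped as get_kv_cache_shape returns: (num_kv_heads, num_blocks, block_size, head_size). A small sketch of that layout with made-up sizes and plain CPU tensors so it runs anywhere; the real worker allocates the first pair of each layer on the TPU device:

from typing import List, Tuple
import torch

num_layers, num_kv_heads, head_size, block_size = 2, 8, 128, 16
num_tpu_blocks, num_cpu_blocks = 16, 32   # made-up sizes

# Shape returned by PallasAttentionBackend.get_kv_cache_shape(...)
tpu_shape = (num_kv_heads, num_tpu_blocks, block_size, head_size)
cpu_shape = (num_kv_heads, num_cpu_blocks, block_size, head_size)

tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
for _ in range(num_layers):
    # One (key, value) pair per layer; blocks live along dim 1.
    tpu_cache.append((torch.zeros(tpu_shape), torch.zeros(tpu_shape)))
    cpu_cache.append((torch.zeros(cpu_shape), torch.zeros(cpu_shape)))
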
@@ -227,18 +228,25 @@ class TPUWorker(LoraNotSupportedWorkerBase):
 
         if blocks_to_swap_in:
             # Swap from CPU to TPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
-                                          self.device)
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_in, "cpu", self.device)
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                k = cpu_k_cache[:, src_indices].to(self.device)
+                v = cpu_v_cache[:, src_indices].to(self.device)
+                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
+
         if blocks_to_swap_out:
             # Swap from TPU to CPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
-                                          "cpu")
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_out, self.device, "cpu")
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
+
         if blocks_to_copy:
             src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                           self.device)
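
cache_swap now unpacks the (src, dst) index tensors once per direction and does the gather/scatter along the block dimension (dim 1) itself, instead of calling the backend's swap_blocks per layer. Based on the call sites here and the fragment of _make_src_to_dst visible in the next hunk, a hedged sketch of what that helper presumably does; the argument names and the example mapping are assumptions:

from typing import List, Tuple, Union
import torch

def make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Split (src_block, dst_block) pairs into two int64 index tensors, each
    # placed on the device whose cache it will index into.
    src_indices = torch.tensor([src for src, _ in mapping],
                               device=src_device, dtype=torch.int64)
    dst_indices = torch.tensor([dst for _, dst in mapping],
                               device=dst_device, dtype=torch.int64)
    return src_indices, dst_indices

src_idx, dst_idx = make_src_to_dst([(0, 5), (3, 6)], "cpu", "cpu")
print(src_idx.tolist(), dst_idx.tolist())  # [0, 3] [5, 6]
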
@@ -267,3 +275,17 @@ def _make_src_to_dst(
                                device=dst_device,
                                dtype=torch.int64)
     return src_indices, dst_indices
+
+
+@torch.compile(backend="openxla")
+def _insert_kv(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    indices: torch.Tensor,
+    tpu_k_cache: torch.Tensor,
+    tpu_v_cache: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
+    tpu_k_cache[:, indices] = k
+    tpu_v_cache[:, indices] = v
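
The new _insert_kv is the TPU-side half of a swap-in: it is compiled with the openxla backend, donates both cache buffers, and writes the staged key/value blocks into the slots given by indices. A CPU-only illustration of the index-assignment semantics (buffer donation and the openxla backend are omitted so the snippet runs anywhere); sizes and indices are made up:

import torch

num_kv_heads, num_blocks, block_size, head_size = 2, 6, 4, 8
tpu_k_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)
tpu_v_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)

# Two swapped-in blocks headed for cache slots 1 and 4.
indices = torch.tensor([1, 4])
k = torch.ones(num_kv_heads, 2, block_size, head_size)
v = torch.ones(num_kv_heads, 2, block_size, head_size)

tpu_k_cache[:, indices] = k
tpu_v_cache[:, indices] = v
assert tpu_k_cache[:, 1].eq(1).all() and tpu_k_cache[:, 0].eq(0).all()
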