vllm/vllm/_ipex_ops.py

284 lines
9.5 KiB
Python

from typing import List, Optional, Tuple
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)

# IPEX is optional at import time: a missing package only produces a
# warning here, and this block then leaves the name ``ipex`` unbound, so
# any ``ipex_ops`` method that dereferences ``ipex`` would raise NameError.
try:
    import intel_extension_for_pytorch as ipex
except ImportError as import_err:
    logger.warning("Import error msg: %s", import_err.msg)
class ipex_ops:
    """Namespace of static wrappers routing vLLM custom ops to IPEX / XPU.

    Every method is a ``@staticmethod``; the class is only used as a
    namespace. Most methods return ``None`` and write their result into a
    tensor argument (``out``, ``input``, the KV caches, ...), either directly
    via ``copy_`` or by handing that tensor to the backend kernel.
    """

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split ``x`` shaped ``[num, 2 * d]`` into two ``[num, d]`` halves.

        ``x1`` is the first ``d`` columns of ``x`` and ``x2`` the last ``d``.
        """
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Gated SiLU on the two halves of ``x``, written into ``out``.

        Delegates to ``ipex.llm.functional.silu_mul(x1, x2, out)``, which
        presumably computes ``silu(x1) * x2`` -- confirm against IPEX docs.
        """
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.silu_mul(x1, x2, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Gated GELU via ``ipex.llm.functional.gelu_mul(..., "none")``.

        ``"none"`` presumably selects the exact (erf) GELU -- confirm in IPEX.
        """
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "none")

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Gated GELU via ``ipex.llm.functional.gelu_mul(..., "tanh")``."""
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")

    @staticmethod
    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
        """Write the "fast" GELU of ``x`` into ``out``.

        Computes ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x*x)))``,
        i.e. the tanh approximation of GELU.
        """
        # The plain ``gelu(x)`` call is the exact erf form and does not match
        # this op's tanh-based definition.
        out.copy_(torch.nn.functional.gelu(x, approximate="tanh"))

    @staticmethod
    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
        """Write the "new" GELU of ``x`` into ``out``.

        Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``,
        i.e. the tanh approximation of GELU.
        """
        out.copy_(torch.nn.functional.gelu(x, approximate="tanh"))

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        """Write the "quick" GELU ``x * sigmoid(1.702 * x)`` into ``out``."""
        out.copy_(x * torch.sigmoid(1.702 * x))

    @staticmethod
    def _kv_head_mapping(num_heads: int, num_kv_heads: int,
                         device: torch.device) -> torch.Tensor:
        """Return an int32 tensor mapping each query head to its KV head.

        Each KV-head index ``0..num_kv_heads-1`` is repeated
        ``num_heads // num_kv_heads`` times, e.g. ``num_heads=4`` and
        ``num_kv_heads=2`` give ``[0, 0, 1, 1]``.
        """
        num_queries_per_kv = num_heads // num_kv_heads
        # repeat_interleave with an int repeat count and no ``dim`` flattens
        # its input and returns a 1-D tensor.
        return torch.arange(
            0,
            num_kv_heads,
            device=device,
            dtype=torch.int32,
        ).repeat_interleave(num_queries_per_kv)

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Run ``torch.xpu.paged_attention_v1`` over the paged KV cache.

        The number of query heads is read from ``out.size(1)``. Only
        ``kv_cache_dtype == "auto"`` is supported. ``k_scale``, ``v_scale``,
        ``tp_rank`` and the ``blocksparse_*`` arguments are accepted
        (presumably to mirror the other backends' signature) but ignored.
        """
        assert kv_cache_dtype == "auto", (
            "ipex paged_attention_v1 only supports kv_cache_dtype='auto', "
            f"got {kv_cache_dtype!r}")
        head_mapping = ipex_ops._kv_head_mapping(out.size(1), num_kv_heads,
                                                 query.device)
        # TODO: ipex will refactor namespace
        torch.xpu.paged_attention_v1(  # type: ignore
            out,
            query.contiguous(),
            # NOTE(review): the key cache is reinterpreted with the value
            # cache's shape; presumably the XPU kernel expects both caches
            # in the same layout -- confirm.
            key_cache.view_as(value_cache),
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Run ``torch.xpu.paged_attention_v2`` over the paged KV cache.

        Like :meth:`paged_attention_v1` but also passes the ``exp_sum``,
        ``max_logits`` and ``tmp_out`` work buffers to the kernel. Only
        ``kv_cache_dtype == "auto"`` is supported; ``k_scale``, ``v_scale``,
        ``tp_rank`` and the ``blocksparse_*`` arguments are ignored.
        """
        assert kv_cache_dtype == "auto", (
            "ipex paged_attention_v2 only supports kv_cache_dtype='auto', "
            f"got {kv_cache_dtype!r}")
        head_mapping = ipex_ops._kv_head_mapping(out.size(1), num_kv_heads,
                                                 query.device)
        # TODO: ipex will refactor namespace
        # NOTE(review): ``scale`` comes after ``context_lens`` here, unlike
        # v1; presumably this matches the torch.xpu v2 signature -- confirm.
        torch.xpu.paged_attention_v2(  # type: ignore
            out,
            exp_sum,
            max_logits,
            tmp_out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            head_mapping,
            block_tables,
            context_lens,
            scale,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def _expand_cos_sin(
            cos_sin: torch.Tensor,
            is_neox: bool) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split gathered cache rows into ``cos``/``sin`` widened to rot_dim.

        ``cos_sin`` is ``[batch, seq, rot_dim]`` with cos in the first half
        of the last dim and sin in the second. NeoX style tiles each half
        (``[c0..cn, c0..cn]``); otherwise each value is interleaved
        (``[c0, c0, c1, c1, ...]``). A singleton dim is inserted at -2 so
        the result broadcasts over the heads dim of query/key.
        """
        cos, sin = cos_sin.chunk(2, dim=-1)
        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        return cos, sin

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        """Apply rotary position embedding to ``query`` and ``key``.

        A 1-D ``positions`` (with correspondingly flat ``query``/``key``) is
        given a leading batch dim first. Only the first ``rot_dim`` channels
        of each head are passed to the kernel. Nothing is copied back, so the
        IPEX kernel presumably rotates those views in place.
        """
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
        rotary_dim = cos_sin_cache.size(1)
        # Split the hidden dim into (heads, head_size); .view and the slices
        # below alias the caller's storage.
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)
        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]
        cos, sin = ipex_ops._expand_cos_sin(cos_sin_cache[positions.long()],
                                            is_neox)
        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)

    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        """Rotary embedding with a per-token offset into ``cos_sin_cache``.

        Same as :meth:`rotary_embedding`, except cache rows are looked up at
        ``positions + cos_sin_cache_offsets``; the kernel still receives the
        un-offset ``positions``. ``rot_dim`` is ignored; the rotary width is
        taken from ``cos_sin_cache.size(1)``.
        """
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
        rotary_dim = cos_sin_cache.size(1)
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)
        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]
        cos_sin = cos_sin_cache[torch.add(positions,
                                          cos_sin_cache_offsets).long()]
        cos, sin = ipex_ops._expand_cos_sin(cos_sin, is_neox)
        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)

    @staticmethod
    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> None:
        """Compute IPEX ``rms_norm(input, weight, epsilon)`` into ``out``."""
        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
        out.copy_(tmp)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        """Residual-add + RMSNorm via IPEX, result copied into ``input``.

        Calls ``add_rms_norm(residual, input, weight, None, epsilon, True)``;
        ``None`` is the bias. The trailing ``True`` is presumably IPEX's
        ``add_back`` flag, which stores ``input + residual`` back into
        ``residual`` -- confirm against IPEX docs.
        """
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
        """Forward every argument, in order, to IPEX ``varlen_attention``.

        The result is presumably written into ``out`` by the kernel.
        """
        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
                                             seqlen_k, max_seqlen_q,
                                             max_seqlen_k, pdropout,
                                             softmax_scale, zero_tensors,
                                             is_causal, return_softmax, gen_)

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Store ``key``/``value`` into the paged caches at ``slot_mapping``.

        Delegates to IPEX ``PagedAttention.reshape_and_cache``. Only
        ``kv_cache_dtype == "auto"`` is supported; ``k_scale`` and
        ``v_scale`` are ignored.
        """
        assert kv_cache_dtype == "auto", (
            "ipex reshape_and_cache only supports kv_cache_dtype='auto', "
            f"got {kv_cache_dtype!r}")
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def copy_blocks(key_caches: List[torch.Tensor],
                    value_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        """Copy cache blocks per ``block_mapping`` via ``torch.xpu``."""
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        """Swap cache blocks from ``src`` to ``dst`` via ``torch.xpu``."""
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore