from typing import List, Optional, Tuple

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.warning("Import error msg: %s", e.msg)


class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
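        # Split a [num_tokens, 2 * d] activation into two [num_tokens, d]
        # halves so the fused *_and_mul ops can multiply the gate and up
        # projections.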
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.silu_mul(x1, x2, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "none")

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")

    @staticmethod
    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))

    @staticmethod
    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))
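    # Note: gelu_fast and gelu_new above both fall back to PyTorch's exact
    # (erf-based) GELU; the corresponding CUDA kernels use tanh-style
    # approximations, so results may differ slightly.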

    # TODO add implementation of gelu_quick here
    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
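        # Only the default "auto" kv-cache dtype is supported here; the fp8
        # scaling factors (k_scale/v_scale) and the tp_rank/blocksparse_*
        # arguments are accepted for interface compatibility but not used by
        # this backend.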
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
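        # Map each query head to its KV head so grouped-query/multi-query
        # attention can share KV cache entries.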
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device=query.device,
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v1(  # type: ignore
            out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
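        # v2 is the partitioned variant: exp_sum, max_logits and tmp_out hold
        # per-partition results that are reduced into `out`. As with v1,
        # k_scale/v_scale and the blocksparse_* arguments are unused here.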
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            dtype=torch.int32,
            device=query.device,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v2(  # type: ignore
            out,
            exp_sum,
            max_logits,
            tmp_out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            head_mapping,
            block_tables,
            context_lens,
            scale,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)

        rotary_dim = cos_sin_cache.size(1)
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)

        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]

        cos_sin = cos_sin_cache[positions.long()]
        cos, sin = cos_sin.chunk(2, dim=-1)
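
        # NeoX-style RoPE rotates the two halves of the rotary dims, so cos/sin
        # are tiled across the halves; the GPT-J style rotates interleaved
        # pairs, so they are repeated element-wise instead.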
        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)

    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
        rotary_dim = cos_sin_cache.size(1)
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)

        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]
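
        # cos_sin_cache_offsets shifts each position into its own slice of the
        # cache, e.g. when different requests use different RoPE scaling
        # factors.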
        cos_sin = cos_sin_cache[torch.add(positions,
                                          cos_sin_cache_offsets).long()]
        cos, sin = cos_sin.chunk(2, dim=-1)

        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)

    @staticmethod
    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> None:
        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
        out.copy_(tmp)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
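        # The IPEX kernel computes RMSNorm over (input + residual); the result
        # is copied back into `input` to mirror the in-place semantics of the
        # CUDA fused_add_rms_norm op.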
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
                                             seqlen_k, max_seqlen_q,
                                             max_seqlen_k, pdropout,
                                             softmax_scale, zero_tensors,
                                             is_causal, return_softmax, gen_)

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
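        # As in the paged attention ops, only the "auto" kv-cache dtype is
        # supported, so k_scale/v_scale are ignored.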
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def copy_blocks(key_caches: List[torch.Tensor],
                    value_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore