2023-05-15 13:32:38 +08:00
|
|
|
"""Custom activation functions."""
|
2023-12-03 13:18:40 +08:00
|
|
|
import math
|
2023-11-19 09:56:47 +08:00
|
|
|
from typing import Optional
|
|
|
|
|
|
2023-04-02 15:30:17 +08:00
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
2023-12-03 13:18:40 +08:00
|
|
|
import torch.nn.functional as F
|
2023-04-02 15:30:17 +08:00
|
|
|
|
2024-04-11 06:33:30 +08:00
|
|
|
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
|
|
|
|
get_tensor_model_parallel_world_size)
|
2024-06-06 00:18:19 +08:00
|
|
|
from vllm.model_executor.custom_op import CustomOp
|
2023-11-19 09:56:47 +08:00
|
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
2023-11-21 13:42:45 +08:00
|
|
|
from vllm.model_executor.utils import set_weight_attrs
|
2023-04-02 15:30:17 +08:00
|
|
|
|
|
|
|
|
|
2024-06-06 00:18:19 +08:00
|
|
|
class SiluAndMul(CustomOp):
    """Gated SiLU activation, as used by SwiGLU.

    The last dimension of the input is cut at ``d = x.shape[-1] // 2`` and
    the result is ``silu(x[..., :d]) * x[..., d:]``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    @staticmethod
    def _fused_silu_and_mul(ops, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and call ``ops.silu_and_mul``.

        ``ops`` is whichever backend op namespace the caller imported; it
        is expected to fill the preallocated tensor in place.
        """
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        ops.silu_and_mul(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, value = x[..., :half], x[..., half:]
        return F.silu(gate) * value

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation, delegated to ``vllm._custom_ops``."""
        # Imported lazily, exactly where the op namespace is needed.
        from vllm import _custom_ops as ops

        return self._fused_silu_and_mul(ops, x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation, delegated to ``vllm._ipex_ops.ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        return self._fused_silu_and_mul(ops, x)
|
|
|
|
|
|
2023-08-23 06:43:21 +08:00
|
|
|
|
2024-06-06 00:18:19 +08:00
|
|
|
class GeluAndMul(CustomOp):
    """Gated GELU activation, as used by GeGLU.

    The last dimension of the input is cut at ``d = x.shape[-1] // 2`` and
    the result is ``GELU(x[..., :d]) * x[..., d:]``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Store the GELU mode; only "none" and "tanh" are accepted.

        Raises:
            ValueError: if ``approximate`` is any other value.
        """
        super().__init__()
        self.approximate = approximate
        supported_modes = ("none", "tanh")
        if approximate not in supported_modes:
            raise ValueError(f"Unknown approximate mode: {approximate}")

    def _fused_gelu_and_mul(self, ops, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and call the matching ``ops`` op.

        ``ops`` is whichever backend op namespace the caller imported; the
        op picked by ``self.approximate`` is expected to fill the
        preallocated tensor in place.
        """
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        mode = self.approximate
        if mode == "none":
            ops.gelu_and_mul(result, x)
        elif mode == "tanh":
            ops.gelu_tanh_and_mul(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, value = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=self.approximate) * value

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation, delegated to ``vllm._custom_ops``."""
        # Imported lazily, exactly where the op namespace is needed.
        from vllm import _custom_ops as ops

        return self._fused_gelu_and_mul(ops, x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation, delegated to ``vllm._ipex_ops.ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        return self._fused_gelu_and_mul(ops, x)

    def extra_repr(self) -> str:
        """Return the ``approximate`` setting as ``approximate=<repr>``."""
        return f'approximate={self.approximate!r}'
|
|
|
|
|
|
2024-02-22 12:17:52 +08:00
|
|
|
|
2024-06-06 00:18:19 +08:00
|
|
|
class NewGELU(CustomOp):
    """Element-wise tanh-based GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    The output has the same shape as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        cubic = x + 0.044715 * torch.pow(x, 3.0)
        gate = 1.0 + torch.tanh(scale * cubic)
        return 0.5 * x * gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation via ``gelu_new`` from ``vllm._custom_ops``."""
        # Imported lazily, exactly where the op namespace is needed.
        from vllm import _custom_ops as ops

        result = torch.empty_like(x)
        ops.gelu_new(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation via ``gelu_new`` from ``vllm._ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        result = torch.empty_like(x)
        ops.gelu_new(result, x)
        return result
|
|
|
|
|
|
2023-08-23 06:43:21 +08:00
|
|
|
|
2024-06-06 00:18:19 +08:00
|
|
|
class FastGELU(CustomOp):
    """Element-wise tanh-based GELU with a hard-coded scale constant.

    Computes
    ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x**2)))``.
    The output has the same shape as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        tanh_arg = x * 0.7978845608 * poly
        return 0.5 * x * (1.0 + torch.tanh(tanh_arg))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation via ``gelu_fast`` from ``vllm._custom_ops``."""
        # Imported lazily, exactly where the op namespace is needed.
        from vllm import _custom_ops as ops

        result = torch.empty_like(x)
        ops.gelu_fast(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same computation via ``gelu_fast`` from ``vllm._ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        result = torch.empty_like(x)
        ops.gelu_fast(result, x)
        return result
|
|
|
|
|
|
2023-08-23 06:43:21 +08:00
|
|
|
|
2023-11-19 09:56:47 +08:00
|
|
|
class ScaledActivation(nn.Module):
    """Wrap an activation module and divide its output by learned scales.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Create the ``scales`` parameter and attach its weight loader.

        Args:
            act_module: activation applied before the division.
            intermediate_size: full (unsharded) length of the scale vector.
            input_is_parallel: when true, this instance only holds
                ``intermediate_size / tensor-parallel world size`` scales.
            params_dtype: dtype of ``scales``; falls back to the current
                torch default dtype when ``None``.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel

        # Start from the unsharded size and shrink it only when sharded.
        partition_size = intermediate_size
        if input_is_parallel:
            world_size = get_tensor_model_parallel_world_size()
            partition_size = divide(intermediate_size, world_size)

        dtype = (torch.get_default_dtype()
                 if params_dtype is None else params_dtype)

        # Uninitialized storage; values are filled in by weight_loader.
        self.scales = nn.Parameter(torch.empty(partition_size, dtype=dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``act(x) / scales``."""
        activated = self.act(x)
        return activated / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's slice of it) into ``param``.

        When ``input_is_parallel`` is set, the slice along dim 0 starting at
        ``rank * shard_size`` with length ``shard_size`` is taken first,
        where ``shard_size`` is ``param``'s own length. The shapes must
        match afterwards.
        """
        target = param.data
        if self.input_is_parallel:
            rank = get_tensor_model_parallel_rank()
            shard_len = target.shape[0]
            loaded_weight = loaded_weight.narrow(0, rank * shard_len,
                                                 shard_len)
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)
|
|
|
|
|
|
2023-11-19 09:56:47 +08:00
|
|
|
|
2023-08-23 06:43:21 +08:00
|
|
|
# Lower-case activation name -> module instance. Each instance is built once
# at import time, and get_act_fn hands out these same objects on every lookup.
_ACTIVATION_REGISTRY = {
    # Exact GELU from torch.nn.
    "gelu": nn.GELU(),
    # Custom ops defined in this module.
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    # torch.nn GELU with approximate="tanh".
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}
|
|
|
|
|
|
|
|
|
|
|
2023-11-19 09:56:47 +08:00
|
|
|
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Look up an activation module by (case-insensitive) name.

    Args:
        act_fn_name: registry key; lower-cased before the lookup.
        quant_config: when given and it lists the activation among its
            scaled activation names, the module is wrapped in
            ``ScaledActivation``.
        intermediate_size: size passed to ``ScaledActivation``; required
            whenever the wrapper is used.
        input_is_parallel: forwarded to ``ScaledActivation``.
        params_dtype: forwarded to ``ScaledActivation``.

    Returns:
        The registered module, or a ``ScaledActivation`` around it.

    Raises:
        ValueError: if the name is not registered, or if scaling is needed
            but ``intermediate_size`` is ``None``.
    """
    name = act_fn_name.lower()
    if name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {name!r} is not supported.")
    act_fn = _ACTIVATION_REGISTRY[name]

    # Plain (unscaled) path: no quantization config, or the config does not
    # ask for this activation to be scaled.
    wants_scaling = (quant_config is not None
                     and name in quant_config.get_scaled_act_names())
    if not wants_scaling:
        return act_fn

    if intermediate_size is None:
        raise ValueError("intermediate_size must be specified for scaled "
                         "activation functions.")
    return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                            params_dtype)
|