Support YaRN models (#1264)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Viktor Ferenczi <viktor@ferenczi.eu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 555bdcc5a3
commit 9f669a9a7c
@@ -16,8 +16,8 @@ __global__ void silu_and_mul_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
     const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
     out[token_idx * d + idx] = silu(x) * y;
@@ -30,7 +30,7 @@ void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int num_tokens = input.numel() / input.size(-1);
+  int64_t num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;

   dim3 grid(num_tokens);
@@ -55,8 +55,8 @@ __global__ void activation_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * d + idx]);
     out[token_idx * d + idx] = ACT_FN(x);
   }
@@ -67,7 +67,7 @@ __global__ void activation_kernel(
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                 \
   int d = input.size(-1);                                                \
-  int num_tokens = input.numel() / d;                                    \
+  int64_t num_tokens = input.numel() / d;                                \
   dim3 grid(num_tokens);                                                 \
   dim3 block(std::min(d, 1024));                                         \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
@@ -84,7 +84,7 @@ void rotary_embedding(
   int head_size,
   torch::Tensor& cos_sin_cache,    // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.numel() / query.size(-1);
+  int64_t num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
   int num_kv_heads = key.size(-1) / head_size;
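The kernel changes above all make the same fix: the token index and loop index are widened from int to int64_t so that flattened offsets such as token_idx * 2 * d + idx are computed in 64-bit arithmetic and cannot overflow for very large inputs. A minimal back-of-the-envelope sketch in Python, using an assumed hidden dimension (the number is illustrative, not from this commit):

# Rough overflow check (illustrative only); d is an assumed value.
d = 13824                        # hypothetical intermediate size
int32_max = 2**31 - 1
max_safe_tokens = int32_max // (2 * d)
print(max_safe_tokens)           # ~77k tokens before token_idx * 2 * d exceeds int32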
@@ -390,6 +390,9 @@ def _get_and_verify_max_len(
     if rope_scaling is not None:
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
+        if rope_scaling["type"] == "yarn":
+            derived_max_model_len = rope_scaling[
+                "original_max_position_embeddings"]
         derived_max_model_len *= scaling_factor

     if max_model_len is None:
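For a YaRN checkpoint, the hunk above derives the servable context length from original_max_position_embeddings multiplied by the scaling factor, rather than from the scaled position count directly. A small sketch of that arithmetic with hypothetical config values (not taken from any specific checkpoint):

# Hypothetical rope_scaling block; values are illustrative only.
rope_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 4096,
}
scaling_factor = rope_scaling["factor"]
derived_max_model_len = rope_scaling["original_max_position_embeddings"]
derived_max_model_len *= scaling_factor
print(derived_max_model_len)  # 16384.0 -> a 4k base model served with a 16k window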
@@ -12,7 +12,7 @@ from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.rotary_embedding import (
     DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding)
+    RotaryEmbedding, YaRNScalingRotaryEmbedding)

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -334,6 +334,19 @@ class PagedAttentionWithRoPE(PagedAttention):
                 self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                     head_size, rotary_dim, max_position, base, is_neox_style,
                     scaling_factor)
+            elif scaling_type == "yarn":
+                original_max_position = rope_scaling[
+                    "original_max_position_embeddings"]
+                assert max_position == original_max_position * scaling_factor
+                extra_kwargs = {
+                    k: v
+                    for k, v in rope_scaling.items()
+                    if k in ("extrapolation_factor", "attn_factor",
+                             "beta_fast", "beta_slow")
+                }
+                self.rotary_emb = YaRNScalingRotaryEmbedding(
+                    head_size, rotary_dim, original_max_position, base,
+                    is_neox_style, scaling_factor, **extra_kwargs)
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

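The new "yarn" branch above checks that the model's max_position equals original_max_position_embeddings * factor, and forwards only the recognized YaRN knobs (extrapolation_factor, attn_factor, beta_fast, beta_slow) to the embedding class. A sketch of that selection with an assumed rope_scaling dict (illustrative values):

# Assumed HF-style rope_scaling entry; values are illustrative only.
rope_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
}
max_position = 16384
assert max_position == (rope_scaling["original_max_position_embeddings"] *
                        rope_scaling["factor"])
extra_kwargs = {
    k: v
    for k, v in rope_scaling.items()
    if k in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
}
print(extra_kwargs)  # {'beta_fast': 32, 'beta_slow': 1}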
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Rotary Positional Embeddings."""
+import math
 from typing import Tuple, Union

 import torch
@@ -167,3 +168,106 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+# Inverse dim formula to find dim based on number of rotations
+def _yarn_find_correction_dim(num_rotations: int,
+                              dim: int,
+                              base: float = 10000,
+                              max_position_embeddings: int = 2048) -> float:
+    return (dim * math.log(max_position_embeddings /
+                           (num_rotations * 2 * math.pi))) / (2 *
+                                                              math.log(base))
+
+
+# Find dim range bounds based on rotations
+def _yarn_find_correction_range(low_rot: int,
+                                high_rot: int,
+                                dim: int,
+                                base: float = 10000,
+                                max_position_embeddings: int = 2048) -> int:
+    low = math.floor(
+        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(
+        _yarn_find_correction_dim(high_rot, dim, base,
+                                  max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
+                           dtype: torch.dtype,
+                           device: torch.device) -> torch.Tensor:
+    if low == high:
+        high += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
+                   low) / (high - low)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def _yarn_get_mscale(scale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * math.log(scale) + 1.0
+
+
+class YaRNScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: float = 32,
+        beta_slow: float = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(
+            _yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2, dtype=torch.float,
+            device="cuda")) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         device="cuda",
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
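As a rough sanity check of the helpers added above, here is a worked example with assumed parameters (rotary_dim=128, base=10000, 4096 original positions, 4x scaling; the numbers are illustrative and not part of the commit):

import math

# Restates _yarn_find_correction_dim from the diff so the example is self-contained.
def find_correction_dim(num_rotations, dim, base=10000,
                        max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 * math.log(base))

low = math.floor(find_correction_dim(32, 128, 10000, 4096))   # beta_fast boundary
high = math.ceil(find_correction_dim(1, 128, 10000, 4096))    # beta_slow boundary
mscale = 0.1 * math.log(4.0) + 1.0                            # _yarn_get_mscale(4.0)

print(low, high)         # 20 46: dims below ~20 keep their original frequencies,
                         # dims above ~46 are fully interpolated
print(round(mscale, 3))  # 1.139: cos/sin cache entries are scaled up slightly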