Implement XPos (Sun et al.)
This commit is contained in:
parent
c2407dec96
commit
496e4f528c
@ -78,10 +78,11 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
|
||||
class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cos, sin):
|
||||
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
|
||||
"""
|
||||
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2)
|
||||
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
|
||||
rotary_dim must be <= headdim
|
||||
Apply rotary embedding *inplace* to the first rotary_dim of q and k.
|
||||
"""
|
||||
@ -91,19 +92,21 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim
|
||||
assert seqlen <= rotary_seqlen
|
||||
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
cos_k = cos if cos_k is None else cos_k
|
||||
sin_k = sin if sin_k is None else sin_k
|
||||
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
|
||||
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
|
||||
k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
|
||||
rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False)
|
||||
ctx.save_for_backward(cos, sin)
|
||||
rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
|
||||
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
||||
return qkv
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dqkv):
|
||||
cos, sin = ctx.saved_tensors
|
||||
cos, sin, cos_k, sin_k = ctx.saved_tensors
|
||||
_, seqlen, _, _, headdim = dqkv.shape
|
||||
rotary_dim = cos.shape[-1]
|
||||
rotary_dim *= 2
|
||||
@ -111,9 +114,9 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
||||
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
|
||||
dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
|
||||
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
|
||||
return dqkv, None, None
|
||||
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
|
||||
return dqkv, None, None, None, None
|
||||
|
||||
|
||||
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
|
||||
@ -134,15 +137,24 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, base=10000, *_, **__):
|
||||
def __init__(self, dim: int, base=10000, scale_base=0, *_, **__):
|
||||
"""
|
||||
If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
||||
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
||||
"""
|
||||
super().__init__()
|
||||
# Generate and save the inverse frequency buffer (non trainable)
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self.scale_base = scale_base
|
||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) if scale_base > 0 else None
|
||||
self.register_buffer("scale", scale)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
|
||||
def _update_cos_sin_cache(self, x, seqlen_offset=0):
|
||||
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
|
||||
@ -157,8 +169,18 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
||||
else:
|
||||
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
||||
- seqlen // 2) / self.scale_base)
|
||||
scale = self.scale ** rearrange(power, 's -> s 1')
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
||||
|
||||
def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@ -166,5 +188,12 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
token in the batch.
|
||||
"""
|
||||
self._update_cos_sin_cache(qkv, seqlen_offset)
|
||||
return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:])
|
||||
if self.scale is None:
|
||||
return apply_rotary_emb_qkv_(
|
||||
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
|
||||
)
|
||||
else:
|
||||
return apply_rotary_emb_qkv_(
|
||||
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
|
||||
self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:]
|
||||
)
|
||||
|
||||
@ -36,11 +36,12 @@ def create_mixer_cls(config, layer_idx=None):
|
||||
softmax_scale /= float(layer_idx + 1)
|
||||
dwconv = getattr(config, 'attn_dwconv', False)
|
||||
rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
|
||||
rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
|
||||
use_flash_attn = getattr(config, 'use_flash_attn', False)
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
|
||||
softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
|
||||
rotary_emb_dim=rotary_emb_dim,
|
||||
rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
|
||||
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
|
||||
return mixer_cls
|
||||
|
||||
|
||||
@ -283,6 +283,7 @@ class MHA(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
|
||||
softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
|
||||
rotary_emb_scale_base=0,
|
||||
fused_bias_fc=False, use_flash_attn=False, return_residual=False,
|
||||
checkpointing=False, device=None, dtype=None) -> None:
|
||||
"""
|
||||
@ -308,7 +309,7 @@ class MHA(nn.Module):
|
||||
if self.rotary_emb_dim > 0:
|
||||
assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
|
||||
assert RotaryEmbedding is not None, 'rotary_emb is not installed'
|
||||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim)
|
||||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
|
||||
|
||||
if fused_bias_fc and FusedDenseTD is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user