diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 1605158..dadee32 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -78,10 +78,11 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply class ApplyRotaryEmbQKV_(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, cos, sin): + def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None): """ qkv: (batch_size, seqlen, 3, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional rotary_dim must be <= headdim Apply rotary embedding *inplace* to the first rotary_dim of q and k. """ @@ -91,19 +92,21 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): rotary_dim *= 2 assert rotary_dim <= headdim assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False) k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'), - rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False) - ctx.save_for_backward(cos, sin) + rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), + rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False) + ctx.save_for_backward(cos, sin, cos_k, sin_k) return qkv @staticmethod def backward(ctx, dqkv): - cos, sin = ctx.saved_tensors + cos, sin, cos_k, sin_k = ctx.saved_tensors _, seqlen, _, _, headdim = dqkv.shape rotary_dim = cos.shape[-1] rotary_dim *= 2 @@ -111,9 +114,9 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True) dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'), - rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True) - return dqkv, None, None + rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), + rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True) + return dqkv, None, None, None, None apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply @@ -134,15 +137,24 @@ class RotaryEmbedding(torch.nn.Module): """ - def __init__(self, dim: int, base=10000, *_, **__): + def __init__(self, dim: int, base=10000, scale_base=0, *_, **__): + """ + If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ super().__init__() # Generate and save the inverse frequency buffer (non trainable) inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) + self.scale_base = scale_base + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) if scale_base > 0 else None + self.register_buffer("scale", scale) self._seq_len_cached = 0 self._cos_cached = None self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None def _update_cos_sin_cache(self, x, seqlen_offset=0): """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim) @@ -157,8 +169,18 @@ class RotaryEmbedding(torch.nn.Module): # Don't do einsum, it converts fp32 to fp16 # freqs = torch.einsum("i,j->ij", t, self.inv_freq) freqs = torch.outer(t, self.inv_freq) - self._cos_cached = torch.cos(freqs).to(x.dtype) - self._sin_cached = torch.sin(freqs).to(x.dtype) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(x.dtype) + self._sin_cached = torch.sin(freqs).to(x.dtype) + else: + power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) + - seqlen // 2) / self.scale_base) + scale = self.scale ** rearrange(power, 's -> s 1') + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -166,5 +188,12 @@ class RotaryEmbedding(torch.nn.Module): token in the batch. """ self._update_cos_sin_cache(qkv, seqlen_offset) - return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:]) + if self.scale is None: + return apply_rotary_emb_qkv_( + qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:] + ) + else: + return apply_rotary_emb_qkv_( + qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], + self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:] + ) diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index d26a5a3..997810a 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -36,11 +36,12 @@ def create_mixer_cls(config, layer_idx=None): softmax_scale /= float(layer_idx + 1) dwconv = getattr(config, 'attn_dwconv', False) rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim) + rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0) use_flash_attn = getattr(config, 'use_flash_attn', False) fused_bias_fc = getattr(config, 'fused_bias_fc', False) mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop, softmax_scale=softmax_scale, causal=True, dwconv=dwconv, - rotary_emb_dim=rotary_emb_dim, + rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base, fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn) return mixer_cls diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 2909168..f52766c 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -283,6 +283,7 @@ class MHA(nn.Module): def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0, softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0, + rotary_emb_scale_base=0, fused_bias_fc=False, use_flash_attn=False, return_residual=False, checkpointing=False, device=None, dtype=None) -> None: """ @@ -308,7 +309,7 @@ class MHA(nn.Module): if self.rotary_emb_dim > 0: assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet' assert RotaryEmbedding is not None, 'rotary_emb is not installed' - self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim) + self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base) if fused_bias_fc and FusedDenseTD is None: raise ImportError('fused_dense is not installed')