From e45a46a5b767d76e14c76e4bfac408b7cf94d896 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 14 Mar 2023 14:28:28 -0700 Subject: [PATCH] [Rotary] Implement GPT-J style (interleaved) rotary --- csrc/rotary/rotary.cpp | 4 ++ csrc/rotary/rotary_cuda.cu | 4 ++ flash_attn/layers/rotary.py | 87 +++++++++++++++++++-------- tests/layers/test_rotary.py | 114 ++++++++++++++++++++++++++++++++++++ 4 files changed, 183 insertions(+), 26 deletions(-) create mode 100644 tests/layers/test_rotary.py diff --git a/csrc/rotary/rotary.cpp b/csrc/rotary/rotary.cpp index 206fda3..b2a3cf0 100644 --- a/csrc/rotary/rotary.cpp +++ b/csrc/rotary/rotary.cpp @@ -1,3 +1,7 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + #include #include diff --git a/csrc/rotary/rotary_cuda.cu b/csrc/rotary/rotary_cuda.cu index 7b68d62..2dd0ff3 100644 --- a/csrc/rotary/rotary_cuda.cu +++ b/csrc/rotary/rotary_cuda.cu @@ -1,3 +1,7 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + #include #include #include diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 222f74e..437b3a7 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -1,4 +1,4 @@ -# Inspired by https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py +# Copyright (c) 2023, Tri Dao. from typing import Tuple import math @@ -10,31 +10,37 @@ from einops import rearrange, repeat import rotary_emb -def rotate_half(x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2) -def apply_rotary_emb_torch(x, cos, sin): +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) """ - rotary_dim = cos.shape[-1] * 2 - assert rotary_dim <= x.shape[-1] + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] cos = repeat(cos, 's d -> s 1 (2 d)') sin = repeat(sin, 's d -> s 1 (2 d)') - return torch.cat([x[..., :rotary_dim] * cos + rotate_half(x[..., :rotary_dim]) * sin, - x[..., rotary_dim:]], dim=-1) + return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:]], dim=-1) class ApplyRotaryEmb(torch.autograd.Function): @staticmethod - def forward(ctx, x, cos, sin, inplace=False): + def forward(ctx, x, cos, sin, interleaved=False, inplace=False): """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). rotary_dim must be <= headdim Apply rotary embedding to the first rotary_dim of x. 
""" @@ -44,14 +50,21 @@ class ApplyRotaryEmb(torch.autograd.Function): assert rotary_dim <= headdim assert seqlen <= rotary_seqlen assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) + x_ro = x[..., :rotary_dim] + x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) out = torch.empty_like(x) if not inplace else x - o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2) + out_ro = out[..., :rotary_dim] + if inplace: + o1, o2 = x1, x2 + else: + o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved + else (out_ro[..., ::2], out_ro[..., 1::2])) rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False) if not inplace and rotary_dim < headdim: out[..., rotary_dim:].copy_(x[..., rotary_dim:]) ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved ctx.inplace = inplace return out if not inplace else x @@ -62,14 +75,21 @@ class ApplyRotaryEmb(torch.autograd.Function): rotary_dim = cos.shape[-1] rotary_dim *= 2 inplace = ctx.inplace - do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1) + do_ro = do[..., :rotary_dim] + do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved + else (do_ro[..., ::2], do_ro[..., 1::2])) dx = torch.empty_like(do) if not inplace else do - dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2) + if inplace: + dx1, dx2 = do1, do2 + else: + dx_ro = dx[..., :rotary_dim] + dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved + else (dx_ro[..., ::2], dx_ro[..., 1::2])) rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True) if not inplace and rotary_dim < headdim: dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None + return dx, None, None, None, None apply_rotary_emb_func = ApplyRotaryEmb.apply @@ -78,11 +98,13 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply class ApplyRotaryEmbQKV_(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None): + def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): """ qkv: (batch_size, seqlen, 3, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). rotary_dim must be <= headdim Apply rotary embedding *inplace* to the first rotary_dim of q and k. 
""" @@ -95,13 +117,16 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): cos_k = cos if cos_k is None else cos_k sin_k = sin if sin_k is None else sin_k assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) + q_ro = qkv[:, :, 0, :, :rotary_dim] + q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2]) rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False) - k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1) + k_ro = qkv[:, :, 1, :, :rotary_dim] + k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False) ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.interleaved = interleaved return qkv @staticmethod @@ -110,13 +135,17 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function): _, seqlen, _, _, headdim = dqkv.shape rotary_dim = cos.shape[-1] rotary_dim *= 2 - dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) + dq_ro = dqkv[:, :, 0, :, :rotary_dim] + dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved + else (dq_ro[..., ::2], dq_ro[..., 1::2])) rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True) - dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1) + dk_ro = dqkv[:, :, 1, :, :rotary_dim] + dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved + else (dk_ro[..., ::2], dk_ro[..., 1::2])) rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True) - return dqkv, None, None, None, None + return dqkv, None, None, None, None, None apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply @@ -135,22 +164,25 @@ class RotaryEmbedding(torch.nn.Module): .. _repo: https://github.com/ZhuiyiTechnology/roformer .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox - If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py """ - def __init__(self, dim: int, base=10000, scale_base=0, device=None): + def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None): """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). 
""" super().__init__() # Generate and save the inverse frequency buffer (non trainable) inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) self.register_buffer("inv_freq", inv_freq) + self.interleaved = interleaved self.scale_base = scale_base scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) - / (1.4 * dim) if scale_base > 0 else None) + / (1.4 * dim) if scale_base is not None else None) self.register_buffer("scale", scale) self._seq_len_cached = 0 @@ -187,16 +219,19 @@ class RotaryEmbedding(torch.nn.Module): def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: """ + qkv: (batch, seqlen, 3, nheads, headdim) seqlen_offset: can be used in generation where the qkv being passed in is only the last token in the batch. """ self._update_cos_sin_cache(qkv, seqlen_offset) if self.scale is None: return apply_rotary_emb_qkv_( - qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:] + qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], + None, None, self.interleaved ) else: return apply_rotary_emb_qkv_( qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], - self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:] + self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:], + self.interleaved ) diff --git a/tests/layers/test_rotary.py b/tests/layers/test_rotary.py new file mode 100644 index 0000000..f41595f --- /dev/null +++ b/tests/layers/test_rotary.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023, Tri Dao. + +import math + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX +from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_neox +from transformers.models.gptj.modeling_gptj import fixed_pos_embedding +from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj + +from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_qkv_ +from flash_attn.layers.rotary import RotaryEmbedding + + +# NeoX-style rotary embedding +@pytest.mark.parametrize('seqlen_offset', [0, 711]) +@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0]) +def test_rotary(rotary_emb_fraction, seqlen_offset): + device = 'cuda' + dtype = torch.float16 + rtol, atol = (1e-3, 5e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen_total = 2048 + seqlen = seqlen_total - seqlen_offset + nheads = 16 + headdim = 128 + rotary_dim = int(headdim * rotary_emb_fraction) + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace + rotary = RotaryEmbedding(rotary_dim, device=device) + rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device) + # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor + cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total) + cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype) + q_pt = rearrange(qkv[:, :, 0, :, :rotary_dim], + 'b s h d -> b h s d').detach().clone().requires_grad_(True) + k_pt = rearrange(qkv[:, :, 1, :, :rotary_dim], + 'b s h d -> b h s d').detach().clone().requires_grad_(True) + q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, 
cos_neox, sin_neox, offset=seqlen_offset) + out = rotary(qkv, seqlen_offset=seqlen_offset) + assert torch.allclose(rotary._cos_cached, cos_neox[..., :rotary_dim // 2].to(dtype=dtype), + rtol=rtol, atol=atol) + assert torch.allclose(rotary._sin_cached, sin_neox[..., :rotary_dim // 2].to(dtype=dtype), + rtol=rtol, atol=atol) + assert torch.allclose(rearrange(q_neox, 'b h s d -> b s h d'), out[:, :, 0, :, :rotary_dim], + rtol=rtol, atol=atol) + assert torch.allclose(rearrange(k_neox, 'b h s d -> b s h d'), out[:, :, 1, :, :rotary_dim], + rtol=rtol, atol=atol) + assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:]) + assert torch.equal(out[:, :, 2], qkv_og[:, :, 2]) + + g = torch.randn_like(out) + g_og = g.clone().detach() # Our implementation modifies g inplace + out.backward(g) + q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], 'b s h d -> b h s d')) + k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], 'b s h d -> b h s d')) + assert torch.allclose(rearrange(q_pt.grad, 'b h s d -> b s h d'), + qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol) + assert torch.allclose(rearrange(k_pt.grad, 'b h s d -> b s h d'), + qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol) + assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:]) + assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2]) + + +# GPT-J-style rotary embedding +@pytest.mark.parametrize('seqlen_offset', [0, 711]) +@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0]) +def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset): + device = 'cuda' + dtype = torch.float16 + rtol, atol = (1e-3, 5e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen_total = 2048 + seqlen = seqlen_total - seqlen_offset + nheads = 16 + headdim = 128 + rotary_dim = int(headdim * rotary_emb_fraction) + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace + rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device) + sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total) + sincos_gptj = tuple(x.to(dtype=dtype) for x in sincos_gptj) + q_pt = qkv[:, :, 0, :, :rotary_dim].detach().clone().requires_grad_(True) + k_pt = qkv[:, :, 1, :, :rotary_dim].detach().clone().requires_grad_(True) + q_gptj = apply_rotary_pos_emb_gptj(q_pt, sincos_gptj, offset=seqlen_offset) + k_gptj = apply_rotary_pos_emb_gptj(k_pt, sincos_gptj, offset=seqlen_offset) + out = rotary(qkv, seqlen_offset=seqlen_offset) + assert torch.allclose(rotary._cos_cached, sincos_gptj[1], rtol=rtol, atol=atol) + assert torch.allclose(rotary._sin_cached, sincos_gptj[0], rtol=rtol, atol=atol) + assert torch.allclose(q_gptj, out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol) + assert torch.allclose(k_gptj, out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol) + assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:]) + assert torch.equal(out[:, :, 2], qkv_og[:, :, 2]) + + g = torch.randn_like(out) + g_og = g.clone().detach() # Our implementation modifies g inplace + out.backward(g) + q_gptj.backward(g_og[:, :, 0, :, :rotary_dim]) + k_gptj.backward(g_og[:, :, 1, :, :rotary_dim]) + assert torch.allclose(q_pt.grad, qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol) + assert torch.allclose(k_pt.grad, qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol) + assert torch.equal(qkv.grad[:, :, 0:2, :, 
rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:]) + assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
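
For reference, the two rotation conventions that the new `interleaved` flag selects between can be summarized in pure PyTorch. The sketch below mirrors `rotate_half` from `flash_attn/layers/rotary.py` above and checks both layouts against an explicit complex-number rotation; the helper names (`apply_rotary_ref`, `apply_rotary_complex`) and the toy sizes are illustrative only, not part of the patch. The key difference is how cos/sin are broadcast over the head dimension: in the GPT-NeoX convention lane i is paired with lane i + rotary_dim/2, so each frequency is repeated across the two halves ('(2 d)'), while in the GPT-J convention lanes 2i and 2i+1 form a pair, so each frequency is repeated for adjacent lanes ('(d 2)').

import torch
from einops import rearrange, repeat


def rotate_half(x, interleaved=False):
    # Same helper as in flash_attn/layers/rotary.py above.
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)                     # GPT-NeoX: split into halves
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]                  # GPT-J: split into even/odd lanes
    return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)


def apply_rotary_ref(x, cos, sin, interleaved=False):
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, headdim / 2).
    # For brevity this rotates the full head dimension; the patch only rotates the
    # first rotary_dim lanes and passes the rest through unchanged.
    pattern = 's d -> s 1 (2 d)' if not interleaved else 's d -> s 1 (d 2)'
    cos, sin = repeat(cos, pattern), repeat(sin, pattern)
    return x * cos + rotate_half(x, interleaved) * sin


def apply_rotary_complex(x, theta, interleaved=False):
    # Unambiguous reference: rotate each 2D pair by theta via complex multiplication.
    pairs = (rearrange(x, '... (d two) -> ... d two', two=2) if interleaved
             else rearrange(x, '... (two d) -> ... d two', two=2))
    z = torch.view_as_complex(pairs.contiguous())
    z = z * torch.polar(torch.ones_like(theta), theta).unsqueeze(1)  # broadcast over heads
    out = torch.view_as_real(z)
    return (rearrange(out, '... d two -> ... (d two)') if interleaved
            else rearrange(out, '... d two -> ... (two d)'))


if __name__ == '__main__':
    torch.manual_seed(0)
    batch, seqlen, nheads, headdim = 2, 8, 4, 16        # toy sizes, illustrative only
    x = torch.randn(batch, seqlen, nheads, headdim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, headdim, 2).float() / headdim))
    theta = torch.outer(torch.arange(seqlen).float(), inv_freq)      # (seqlen, headdim / 2)
    for interleaved in (False, True):
        out = apply_rotary_ref(x, theta.cos(), theta.sin(), interleaved)
        ref = apply_rotary_complex(x, theta, interleaved)
        print(f'interleaved={interleaved}:', torch.allclose(out, ref, atol=1e-6))  # True

In the patch itself, the interleaved path runs through the fused `rotary_emb.apply_rotary` kernel and is exercised via `RotaryEmbedding(rotary_dim, interleaved=True, device=device)`, as in `tests/layers/test_rotary.py` above.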