[Rotary] Implement GPT-J style (interleaved) rotary

This commit is contained in:
Tri Dao 2023-03-14 14:28:28 -07:00
parent f28d61cb2a
commit e45a46a5b7
4 changed files with 183 additions and 26 deletions

View File

@ -1,3 +1,7 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#include <torch/extension.h> #include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>

View File

@ -1,3 +1,7 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#include <torch/python.h> #include <torch/python.h>
#include <ATen/native/TensorIterator.h> #include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh> #include <ATen/native/cuda/Loops.cuh>

View File

@ -1,4 +1,4 @@
# Inspired by https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py # Copyright (c) 2023, Tri Dao.
from typing import Tuple from typing import Tuple
import math import math
@ -10,31 +10,37 @@ from einops import rearrange, repeat
import rotary_emb import rotary_emb
def rotate_half(x): def rotate_half(x, interleaved=False):
x1, x2 = x.chunk(2, dim=-1) if not interleaved:
return torch.cat((-x2, x1), dim=-1) x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
def apply_rotary_emb_torch(x, cos, sin): def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
""" """
x: (batch_size, seqlen, nheads, headdim) x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) cos, sin: (seqlen, rotary_dim / 2)
""" """
rotary_dim = cos.shape[-1] * 2 ro_dim = cos.shape[-1] * 2
assert rotary_dim <= x.shape[-1] assert ro_dim <= x.shape[-1]
cos = repeat(cos, 's d -> s 1 (2 d)') cos = repeat(cos, 's d -> s 1 (2 d)')
sin = repeat(sin, 's d -> s 1 (2 d)') sin = repeat(sin, 's d -> s 1 (2 d)')
return torch.cat([x[..., :rotary_dim] * cos + rotate_half(x[..., :rotary_dim]) * sin, return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., rotary_dim:]], dim=-1) x[..., ro_dim:]], dim=-1)
class ApplyRotaryEmb(torch.autograd.Function): class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x, cos, sin, inplace=False): def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
""" """
x: (batch_size, seqlen, nheads, headdim) x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x. Apply rotary embedding to the first rotary_dim of x.
""" """
@ -44,14 +50,21 @@ class ApplyRotaryEmb(torch.autograd.Function):
assert rotary_dim <= headdim assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2) assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1) x_ro = x[..., :rotary_dim]
x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
out = torch.empty_like(x) if not inplace else x out = torch.empty_like(x) if not inplace else x
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2) out_ro = out[..., :rotary_dim]
if inplace:
o1, o2 = x1, x2
else:
o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2]))
rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False) rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
if not inplace and rotary_dim < headdim: if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:]) out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin) ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace ctx.inplace = inplace
return out if not inplace else x return out if not inplace else x
@ -62,14 +75,21 @@ class ApplyRotaryEmb(torch.autograd.Function):
rotary_dim = cos.shape[-1] rotary_dim = cos.shape[-1]
rotary_dim *= 2 rotary_dim *= 2
inplace = ctx.inplace inplace = ctx.inplace
do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1) do_ro = do[..., :rotary_dim]
do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2]))
dx = torch.empty_like(do) if not inplace else do dx = torch.empty_like(do) if not inplace else do
dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2) if inplace:
dx1, dx2 = do1, do2
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2]))
rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True) rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
if not inplace and rotary_dim < headdim: if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None return dx, None, None, None, None
apply_rotary_emb_func = ApplyRotaryEmb.apply apply_rotary_emb_func = ApplyRotaryEmb.apply
@ -78,11 +98,13 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
class ApplyRotaryEmbQKV_(torch.autograd.Function): class ApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None): def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
""" """
qkv: (batch_size, seqlen, 3, nheads, headdim) qkv: (batch_size, seqlen, 3, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) cos, sin: (seqlen, rotary_dim / 2)
cos_k, sin_k: (seqlen, rotary_dim / 2), optional cos_k, sin_k: (seqlen, rotary_dim / 2), optional
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of q and k. Apply rotary embedding *inplace* to the first rotary_dim of q and k.
""" """
@ -95,13 +117,16 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
cos_k = cos if cos_k is None else cos_k cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k sin_k = sin if sin_k is None else sin_k
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) q_ro = qkv[:, :, 0, :, :rotary_dim]
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False) rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1) k_ro = qkv[:, :, 1, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False) rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
ctx.save_for_backward(cos, sin, cos_k, sin_k) ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv return qkv
@staticmethod @staticmethod
@ -110,13 +135,17 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
_, seqlen, _, _, headdim = dqkv.shape _, seqlen, _, _, headdim = dqkv.shape
rotary_dim = cos.shape[-1] rotary_dim = cos.shape[-1]
rotary_dim *= 2 rotary_dim *= 2
dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1) dq_ro = dqkv[:, :, 0, :, :rotary_dim]
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dq_ro[..., ::2], dq_ro[..., 1::2]))
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True) rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1) dk_ro = dqkv[:, :, 1, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True) rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
return dqkv, None, None, None, None return dqkv, None, None, None, None, None
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
@ -135,22 +164,25 @@ class RotaryEmbedding(torch.nn.Module):
.. _repo: https://github.com/ZhuiyiTechnology/roformer .. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
""" """
def __init__(self, dim: int, base=10000, scale_base=0, device=None): def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
""" """
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
""" """
super().__init__() super().__init__()
# Generate and save the inverse frequency buffer (non trainable) # Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
dtype=torch.float32) / dim)) dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq) self.register_buffer("inv_freq", inv_freq)
self.interleaved = interleaved
self.scale_base = scale_base self.scale_base = scale_base
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim) if scale_base > 0 else None) / (1.4 * dim) if scale_base is not None else None)
self.register_buffer("scale", scale) self.register_buffer("scale", scale)
self._seq_len_cached = 0 self._seq_len_cached = 0
@ -187,16 +219,19 @@ class RotaryEmbedding(torch.nn.Module):
def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
qkv: (batch, seqlen, 3, nheads, headdim)
seqlen_offset: can be used in generation where the qkv being passed in is only the last seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch. token in the batch.
""" """
self._update_cos_sin_cache(qkv, seqlen_offset) self._update_cos_sin_cache(qkv, seqlen_offset)
if self.scale is None: if self.scale is None:
return apply_rotary_emb_qkv_( return apply_rotary_emb_qkv_(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:] qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
None, None, self.interleaved
) )
else: else:
return apply_rotary_emb_qkv_( return apply_rotary_emb_qkv_(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:] self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
self.interleaved
) )

114
tests/layers/test_rotary.py Normal file
View File

@ -0,0 +1,114 @@
# Copyright (c) 2023, Tri Dao.
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_neox
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_qkv_
from flash_attn.layers.rotary import RotaryEmbedding
# NeoX-style rotary embedding
@pytest.mark.parametrize('seqlen_offset', [0, 711])
@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
device = 'cuda'
dtype = torch.float16
rtol, atol = (1e-3, 5e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen_total = 2048
seqlen = seqlen_total - seqlen_offset
nheads = 16
headdim = 128
rotary_dim = int(headdim * rotary_emb_fraction)
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
requires_grad=True)
qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace
rotary = RotaryEmbedding(rotary_dim, device=device)
rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
# Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
q_pt = rearrange(qkv[:, :, 0, :, :rotary_dim],
'b s h d -> b h s d').detach().clone().requires_grad_(True)
k_pt = rearrange(qkv[:, :, 1, :, :rotary_dim],
'b s h d -> b h s d').detach().clone().requires_grad_(True)
q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
out = rotary(qkv, seqlen_offset=seqlen_offset)
assert torch.allclose(rotary._cos_cached, cos_neox[..., :rotary_dim // 2].to(dtype=dtype),
rtol=rtol, atol=atol)
assert torch.allclose(rotary._sin_cached, sin_neox[..., :rotary_dim // 2].to(dtype=dtype),
rtol=rtol, atol=atol)
assert torch.allclose(rearrange(q_neox, 'b h s d -> b s h d'), out[:, :, 0, :, :rotary_dim],
rtol=rtol, atol=atol)
assert torch.allclose(rearrange(k_neox, 'b h s d -> b s h d'), out[:, :, 1, :, :rotary_dim],
rtol=rtol, atol=atol)
assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
g = torch.randn_like(out)
g_og = g.clone().detach() # Our implementation modifies g inplace
out.backward(g)
q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], 'b s h d -> b h s d'))
k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], 'b s h d -> b h s d'))
assert torch.allclose(rearrange(q_pt.grad, 'b h s d -> b s h d'),
qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.allclose(rearrange(k_pt.grad, 'b h s d -> b s h d'),
qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
# GPT-J-style rotary embedding
@pytest.mark.parametrize('seqlen_offset', [0, 711])
@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
device = 'cuda'
dtype = torch.float16
rtol, atol = (1e-3, 5e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen_total = 2048
seqlen = seqlen_total - seqlen_offset
nheads = 16
headdim = 128
rotary_dim = int(headdim * rotary_emb_fraction)
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
requires_grad=True)
qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace
rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
sincos_gptj = tuple(x.to(dtype=dtype) for x in sincos_gptj)
q_pt = qkv[:, :, 0, :, :rotary_dim].detach().clone().requires_grad_(True)
k_pt = qkv[:, :, 1, :, :rotary_dim].detach().clone().requires_grad_(True)
q_gptj = apply_rotary_pos_emb_gptj(q_pt, sincos_gptj, offset=seqlen_offset)
k_gptj = apply_rotary_pos_emb_gptj(k_pt, sincos_gptj, offset=seqlen_offset)
out = rotary(qkv, seqlen_offset=seqlen_offset)
assert torch.allclose(rotary._cos_cached, sincos_gptj[1], rtol=rtol, atol=atol)
assert torch.allclose(rotary._sin_cached, sincos_gptj[0], rtol=rtol, atol=atol)
assert torch.allclose(q_gptj, out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.allclose(k_gptj, out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
g = torch.randn_like(out)
g_og = g.clone().detach() # Our implementation modifies g inplace
out.backward(g)
q_gptj.backward(g_og[:, :, 0, :, :rotary_dim])
k_gptj.backward(g_og[:, :, 1, :, :rotary_dim])
assert torch.allclose(q_pt.grad, qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.allclose(k_pt.grad, qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])