[Rotary] Implement varlen rotary

Tri Dao 2023-09-03 17:40:00 -07:00
parent 861c82577d
commit b28ec236df
3 changed files with 181 additions and 89 deletions


@@ -42,27 +42,37 @@ class ApplyRotaryEmb(torch.autograd.Function):
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
out = apply_rotary(
x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
x,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
interleaved=interleaved,
inplace=inplace,
)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, seqlen_offsets)
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
ctx.inplace = inplace
ctx.max_seqlen = max_seqlen
return out if not inplace else x
@staticmethod
def backward(ctx, do):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, seqlen_offsets = ctx.saved_tensors
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
else:
cos, sin = ctx.saved_tensors
cos, sin, cu_seqlens = ctx.saved_tensors
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
if not ctx.interleaved and not ctx.inplace:
@@ -72,31 +82,46 @@ class ApplyRotaryEmb(torch.autograd.Function):
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=ctx.max_seqlen,
interleaved=ctx.interleaved,
inplace=ctx.inplace,
conjugate=True,
)
return dx, None, None, None, None, None
return dx, None, None, None, None, None, None, None
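For intuition on the backward above (an illustrative aside, not part of this commit): rotary embedding is a pure rotation, so the gradient is obtained by applying the same rotation with the angle negated, which is what conjugate=True does. A minimal self-check, assuming the low-level apply_rotary op is importable from flash_attn's Triton ops (the module path here is an assumption):

    import torch
    from flash_attn.ops.triton.rotary import apply_rotary  # module path is an assumption

    torch.manual_seed(0)
    x = torch.randn(2, 16, 4, 64, device="cuda", dtype=torch.float16)   # (batch, seqlen, nheads, headdim)
    angle = torch.rand(16, 32, device="cuda") * 2 * torch.pi            # (seqlen_ro, rotary_dim / 2)
    cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
    y = apply_rotary(x, cos, sin)                                       # rotate by +angle
    x_back = apply_rotary(y, cos, sin, conjugate=True)                  # rotate by -angle
    torch.testing.assert_close(x_back, x, rtol=1e-2, atol=1e-2)         # recovers the input up to fp16 error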
def apply_rotary_emb(
x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
"""
Arguments:
x: (batch_size, seqlen, nheads, headdim)
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
cos, sin: (seqlen_rotary, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
inplace: if True, apply rotary embedding in-place.
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Return:
out: (batch_size, seqlen, nheads, headdim)
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)
return ApplyRotaryEmb.apply(
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
)
# For backward compatibility

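A minimal usage sketch of the new varlen path (illustrative values, not from this commit): pack the sequences into one (total_seqlen, nheads, headdim) tensor, build cu_seqlens as the cumulative lengths with a leading 0, and pass max_seqlen alongside it.

    import torch
    import torch.nn.functional as F
    from flash_attn.layers.rotary import apply_rotary_emb

    lengths = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")          # per-sequence lengths
    cu_seqlens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))  # [0, 5, 8, 15]
    max_seqlen = int(lengths.max())                                              # 7
    total_seqlen, nheads, headdim, rotary_dim = int(cu_seqlens[-1]), 4, 64, 64
    x = torch.randn(total_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    angle = torch.rand(max_seqlen, rotary_dim // 2, device="cuda") * 2 * torch.pi
    cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
    out = apply_rotary_emb(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    # out keeps the packed shape of x: (total_seqlen, nheads, headdim)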

@@ -1,4 +1,4 @@
from typing import Union
from typing import Optional, Union
import torch
@@ -21,6 +21,7 @@ def rotary_kernel(
X,
COS,
SIN,
CU_SEQLENS,
SEQLEN_OFFSETS, # this could be int or a pointer
# Matrix dimensions
seqlen,
@@ -40,6 +41,7 @@ def rotary_kernel(
# Meta-parameters
BLOCK_K: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
IS_VARLEN: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
BLOCK_M: tl.constexpr,
@@ -49,9 +51,17 @@ def rotary_kernel(
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
if not IS_VARLEN:
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
else:
start_idx = tl.load(CU_SEQLENS + pid_batch)
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
if pid_m * BLOCK_M >= seqlen:
return
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
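A host-side sketch of what the IS_VARLEN branch above does (illustrative, plain PyTorch, names are my own): instead of stepping by a batch stride, each program looks up its sequence's start and length in CU_SEQLENS and offsets the X/OUT pointers by start_idx rows of the packed tensor.

    import torch

    def sequence_view(x_packed: torch.Tensor, cu_seqlens: torch.Tensor, pid_batch: int) -> torch.Tensor:
        # x_packed: (total_seqlen, nheads, headdim); cu_seqlens: (batch + 1,) int32
        start_idx = int(cu_seqlens[pid_batch])
        seqlen = int(cu_seqlens[pid_batch + 1]) - start_idx
        # The kernel adds start_idx * stride_x_seqlen to the base pointer and uses
        # seqlen as this sequence's length; slicing the packed tensor is equivalent.
        return x_packed[start_idx : start_idx + seqlen]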
@@ -134,20 +144,33 @@ def apply_rotary(
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim)
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
batch, seqlen, nheads, headdim = x.shape
is_varlen = cu_seqlens is not None
if not is_varlen:
batch, seqlen, nheads, headdim = x.shape
else:
assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
total_seqlen, nheads, headdim = x.shape
batch_p_1 = cu_seqlens.shape[0]
batch = batch_p_1 - 1
seqlen = max_seqlen
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
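The varlen bookkeeping above in concrete numbers (made-up values): batch is recovered from the length of cu_seqlens, and max_seqlen stands in for seqlen downstream (e.g. when sizing the kernel launch), which is presumably why the caller must pass it explicitly.

    import torch

    cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)  # made-up example, (batch + 1,)
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]                    # per-sequence lengths: [5, 3, 7]
    batch = cu_seqlens.shape[0] - 1                               # 3
    max_seqlen = int(lengths.max())                               # 7, must be supplied by the caller
    total_seqlen = int(cu_seqlens[-1])                            # 15, first dim of the packed x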
@@ -187,22 +210,24 @@ def apply_rotary(
x,
cos,
sin,
cu_seqlens,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
seqlen // 128, # key for triton cache (limit number of compilations)
output.stride(0), # strides
output.stride(1),
output.stride(2),
output.stride(3),
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
output.stride(-3), # seqlen_stride or total_seqlen_stride
output.stride(-2), # nheads_stride
output.stride(-1), # headdim_stride
x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
x.stride(-3), # seqlen stride or total_seqlen_stride
x.stride(-2), # nheads stride
x.stride(-1), # headdim stride
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
is_varlen,
interleaved,
conjugate,
BLOCK_M,

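On the stride changes above (an illustrative aside): the negative indices make the same kernel call work for both layouts, and the batch stride is passed as 0 in the varlen case because the kernel addresses rows through CU_SEQLENS rather than a batch dimension. A small illustration with made-up shapes:

    import torch

    padded = torch.empty(8, 217, 4, 64)   # (batch, seqlen, nheads, headdim)
    packed = torch.empty(1500, 4, 64)     # (total_seqlen, nheads, headdim)

    # stride(-3), stride(-2), stride(-1) pick out the seqlen/nheads/headdim strides
    # in both layouts; only the padded layout has a real batch stride.
    print(padded.stride(0), padded.stride(-3), padded.stride(-2), padded.stride(-1))  # 55552 256 64 1
    print(0,                packed.stride(-3), packed.stride(-2), packed.stride(-1))  # 0 256 64 1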

@@ -7,10 +7,41 @@ import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
from flash_attn.bert_padding import pad_input, unpad_input
is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
return cos, sin
def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
if seqlen_offsets_type == 0:
return 0
elif seqlen_offsets_type is int:
return torch.randint(0, seqlen + 1, (1,)).item()
elif seqlen_offsets_type is torch.Tensor:
return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)
def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
if isinstance(seqlen_offsets, torch.Tensor):
batch_size = seqlen_offsets.shape[0]
arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
return cos_pt, sin_pt
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
@@ -30,35 +61,18 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t
seqlen = 217
headdim = 128
device = "cuda"
rotary_dim = int(rotary_fraction * headdim)
torch.manual_seed(42)
x = torch.randn(
batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
x_pt = x.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
out = apply_rotary_emb(
x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
out_pt = apply_rotary_emb_torch(
x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
@@ -96,35 +110,18 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype
seqlen = 512
headdim = 128
device = "cuda"
rotary_dim = int(rotary_fraction * headdim)
torch.manual_seed(42)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
qkv_pt = qkv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
out = apply_rotary_emb_qkv_(
qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
q_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
@@ -164,35 +161,16 @@ def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype)
seqlen = 781
headdim = 64
device = "cuda"
rotary_dim = int(rotary_fraction * headdim)
torch.manual_seed(42)
kv = torch.randn(
batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
kv_pt = kv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_kv_(
kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved)
cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
k_pt = apply_rotary_emb_torch(
kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
@@ -210,3 +188,67 @@ def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype)
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize("dtype", ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize("rotary_fraction", [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize("interleaved", [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize("inplace", [False])
def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 217
headdim = 128
device = "cuda"
rotary_dim = int(rotary_fraction * headdim)
torch.manual_seed(42)
x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
x_pt = x.detach().clone().requires_grad_()
lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
x_unpad_clone = x_unpad.clone()
x_unpad = x_unpad.requires_grad_()
cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
out_unpad = apply_rotary_emb(
x_unpad,
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=interleaved,
inplace=inplace,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
out = pad_input(out_unpad, indices, batch_size, seqlen)
cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
out_pt = apply_rotary_emb_torch(
x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # If inplace=True, we might modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}")
if not inplace:
assert torch.equal(x_unpad, x_unpad_clone)
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
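For readers unfamiliar with flash_attn.bert_padding: a simplified sketch of the unpad_input / pad_input round trip the new varlen test relies on (not the library implementation; it assumes a boolean mask with True marking valid tokens).

    import torch
    import torch.nn.functional as F

    def unpad_simple(x, padding_mask):
        # x: (batch, seqlen, ...); padding_mask: (batch, seqlen) bool, True = keep
        seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
        indices = padding_mask.flatten().nonzero(as_tuple=False).flatten()
        x_unpad = x.reshape(-1, *x.shape[2:])[indices]       # (total_seqlen, ...)
        return x_unpad, indices, cu_seqlens, int(seqlens.max())

    def pad_simple(x_unpad, indices, batch, seqlen):
        out = x_unpad.new_zeros(batch * seqlen, *x_unpad.shape[1:])
        out[indices] = x_unpad
        return out.reshape(batch, seqlen, *x_unpad.shape[1:])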