[Rotary] Implement rotary in Triton

Tri Dao 2023-09-03 02:44:59 -07:00
parent 08e9847176
commit 942fcbf046
6 changed files with 601 additions and 231 deletions

View File

@@ -1,11 +1,11 @@
# Copyright (c) 2023, Tri Dao.
import math
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import rotary_emb
import torch
from einops import rearrange, repeat
from flash_attn.ops.triton.rotary import apply_rotary
def rotate_half(x, interleaved=False):
@@ -20,12 +20,12 @@ def rotate_half(x, interleaved=False):
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(cos, "s d -> s 1 (2 d)")
sin = repeat(sin, "s d -> s 1 (2 d)")
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
dim=-1,
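A minimal sketch (illustrative, not part of the patch) of the two layouts the reference function above supports: with interleaved=False (GPT-NeoX style) the first half of the rotary dims pairs with the second half, with interleaved=True (GPT-J style) even dims pair with odd dims. Either way each pair undergoes a plain 2D rotation, so its norm is preserved; the check below only verifies the non-interleaved pairing.

import math
import torch

batch, seqlen, nheads, headdim = 1, 4, 1, 8
x = torch.randn(batch, seqlen, nheads, headdim)
angle = torch.rand(seqlen, headdim // 2) * 2 * math.pi
cos, sin = torch.cos(angle), torch.sin(angle)
out_neox = apply_rotary_emb_torch(x, cos, sin, interleaved=False)
out_gptj = apply_rotary_emb_torch(x, cos, sin, interleaved=True)
# Non-interleaved layout: dim i pairs with dim i + headdim // 2, and rotation keeps the pair norm
norm_in = (x[..., :4] ** 2 + x[..., 4:] ** 2).sqrt()
norm_out = (out_neox[..., :4] ** 2 + out_neox[..., 4:] ** 2).sqrt()
assert torch.allclose(norm_in, norm_out, atol=1e-5)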
@@ -34,229 +34,242 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]
if inplace:
o1, o2 = x1, x2
else:
o1, o2 = (
out_ro.chunk(2, dim=-1)
if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2])
)
rotary_emb.apply_rotary(
x1,
x2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
o1,
o2,
False,
def forward(
ctx,
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
):
out = apply_rotary(
x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
)
if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x
@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
do1, do2 = (
do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do) if not inplace else do
if inplace:
dx1, dx2 = do1, do2
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, seqlen_offsets = ctx.saved_tensors
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
rotary_emb.apply_rotary(
do1,
do2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dx1,
dx2,
True,
cos, sin = ctx.saved_tensors
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
if not ctx.interleaved and not ctx.inplace:
do = do.clone()
dx = apply_rotary(
do,
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=ctx.interleaved,
inplace=ctx.inplace,
conjugate=True,
)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
return dx, None, None, None, None, None
apply_rotary_emb_func = ApplyRotaryEmb.apply
def apply_rotary_emb(
x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
):
"""
Arguments:
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen_rotary, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
inplace: if True, apply rotary embedding in-place.
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
Return:
out: (batch_size, seqlen, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)
# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
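A usage sketch of the new apply_rotary_emb entry point (illustrative; the cache length 128 and the offset values are made up, shapes follow the docstring above). The tensor form of seqlen_offsets is the KV-cache decoding case, where each sequence in the batch sits at a different position:

import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb

batch, seqlen, nheads, headdim = 2, 1, 4, 64
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
# cos/sin cache covering positions seen so far (length 128 assumed here)
angle = torch.rand(128, headdim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
# Per-sequence offsets: sequence 0 has 5 cached tokens, sequence 1 has 17
offsets = torch.tensor([5, 17], dtype=torch.int32, device="cuda")
out = apply_rotary_emb(x, cos, sin, interleaved=False, inplace=False, seqlen_offsets=offsets)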
class ApplyRotaryEmbQKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
"""
qkv: (batch_size, seqlen, 3, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of q and k.
"""
def forward(
ctx,
qkv,
cos,
sin,
cos_k=None,
sin_k=None,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
):
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q_ro = qkv[:, :, 0, :, :rotary_dim]
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
rotary_emb.apply_rotary(
q1,
q2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
q1,
q2,
False,
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(
k1,
k2,
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
k1,
k2,
False,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
if cos_k is None and sin_k is None and qkv.is_contiguous():
# Call 1 kernel instead of 2 kernels
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
apply_rotary(
qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
)
else:
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
q, k = qkv[:, :, 0], qkv[:, :, 1]
apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
return qkv
@staticmethod
def backward(ctx, dqkv):
cos, sin, cos_k, sin_k = ctx.saved_tensors
_, seqlen, _, _, headdim = dqkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
dq1, dq2 = (
dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
)
rotary_emb.apply_rotary(
dq1,
dq2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dq1,
dq2,
True,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
dk1, dk2 = (
dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
)
rotary_emb.apply_rotary(
dk1,
dk2,
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
dk1,
dk2,
True,
)
return dqkv, None, None, None, None, None
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
else:
cos, sin, cos_k, sin_k = ctx.saved_tensors
if cos_k is None and sin_k is None and dqkv.is_contiguous():
# Call 1 kernel instead of 2 kernels
# We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
apply_rotary(
dqk,
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
else:
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
apply_rotary(
dq, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True, conjugate=True
)
apply_rotary(
dk,
cos_k,
sin_k,
seqlen_offsets,
interleaved=interleaved,
inplace=True,
conjugate=True,
)
return dqkv, None, None, None, None, None, None
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
def apply_rotary_emb_qkv_(
qkv,
cos,
sin,
cos_k=None,
sin_k=None,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
Most commonly used in inference when we have KV cache.
Return:
qkv: (batch_size, seqlen, 3, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
"""
return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
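A sketch (illustrative) of the single-kernel path taken in ApplyRotaryEmbQKV_.forward above when cos_k/sin_k are None and qkv is contiguous: folding the (2, nheads) dims of the Q and K slices into one head axis gives a view over the same storage, so one inplace apply_rotary launch rotates both.

import torch
from einops import rearrange
from flash_attn.ops.triton.rotary import apply_rotary

qkv = torch.randn(2, 16, 3, 4, 64, dtype=torch.float16, device="cuda")
angle = torch.rand(16, 32, device="cuda")
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")  # a view for contiguous qkv, not a copy
assert qk.data_ptr() == qkv.data_ptr()  # same storage, so inplace=True updates qkv itself
apply_rotary(qk, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True)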
class ApplyRotaryEmbKV_(torch.autograd.Function):
@staticmethod
def forward(ctx, kv, cos, sin, interleaved=False):
"""
kv: (batch_size, seqlen, 2, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of k.
"""
def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
batch, seqlen, two, nheads, headdim = kv.shape
assert two == 2
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
k_ro = kv[:, :, 0, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
rotary_emb.apply_rotary(
k1,
k2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
k1,
k2,
False,
) # conj=False since this is the forward pass
ctx.save_for_backward(cos, sin)
k = kv[:, :, 0]
apply_rotary(
k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
)
if isinstance(seqlen_offsets, int):
ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
return kv
@staticmethod
def backward(ctx, dkv):
cos, sin = ctx.saved_tensors
_, seqlen, _, _, headdim = dkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dk_ro = dkv[:, :, 0, :, :rotary_dim]
dk1, dk2 = (
dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, seqlen_offsets = ctx.saved_tensors
else:
cos, sin = ctx.saved_tensors
apply_rotary(
dkv[:, :, 0],
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=ctx.interleaved,
inplace=True,
conjugate=True,
)
rotary_emb.apply_rotary(
dk1,
dk2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dk1,
dk2,
True,
) # conj=True since this is the backward pass
return dkv, None, None, None
return dkv, None, None, None, None
apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
def apply_rotary_emb_kv_(
kv,
cos,
sin,
interleaved=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
):
"""
Arguments:
kv: (batch_size, seqlen, 2, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount.
Most commonly used in inference when we have KV cache.
Return:
kv: (batch_size, seqlen, 2, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of K.
"""
return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
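A decoding-oriented sketch (illustrative; cache lengths are made-up values) for apply_rotary_emb_kv_: the single new KV entry per sequence is rotated in place at the position given by how many tokens that sequence already holds in its cache.

import torch
from flash_attn.layers.rotary import apply_rotary_emb_kv_

batch, nheads, headdim, max_seqlen = 2, 4, 64, 128
angle = torch.rand(max_seqlen, headdim // 2, device="cuda")
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
kv_new = torch.randn(batch, 1, 2, nheads, headdim, dtype=torch.float16, device="cuda")
cache_seqlens = torch.tensor([10, 37], dtype=torch.int32, device="cuda")  # per-sequence positions
kv_new = apply_rotary_emb_kv_(kv_new, cos, sin, interleaved=False, seqlen_offsets=cache_seqlens)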
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
@@ -372,57 +385,70 @@ class RotaryEmbedding(torch.nn.Module):
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(
self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
self,
qkv: torch.Tensor,
kv: Optional[torch.Tensor] = None,
seqlen_offset: Union[int, torch.Tensor] = 0,
max_seqlen: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
else it's just q of shape (batch, seqlen, nheads, headdim)
kv: (batch, seqlen, 2, nheads, headdim)
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
should pass in max_seqlen, which will update the cos / sin cache up to that length.
Apply rotary embedding *inplace* to qkv and / or kv.
"""
seqlen = qkv.shape[1]
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
if isinstance(seqlen_offset, int):
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
elif max_seqlen is not None:
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
if kv is None:
if self.scale is None:
return apply_rotary_emb_qkv_(
qkv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
None,
None,
self.interleaved,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
return apply_rotary_emb_qkv_(
qkv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self._cos_k_cached[seqlen_offset:],
self._sin_k_cached[seqlen_offset:],
self.interleaved,
self._cos_cached,
self._sin_cached,
self._cos_k_cached,
self._sin_k_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
q = qkv
q = apply_rotary_emb_func(
q,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self.interleaved,
True,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
inplace=True,
seqlen_offsets=seqlen_offset,
)
if self.scale is None:
kv = apply_rotary_emb_kv_(
kv,
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
self.interleaved,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
kv = apply_rotary_emb_kv_(
kv,
self._cos_k_cached[seqlen_offset:],
self._sin_k_cached[seqlen_offset:],
self.interleaved,
self._cos_k_cached,
self._sin_k_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
return q, kv
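A usage sketch (illustrative; constructor arguments other than the rotary dimension and device are assumed to keep their defaults) of the updated RotaryEmbedding.forward during batched decoding, where each sequence sits at a different offset and max_seqlen sizes the cos/sin cache:

import torch
from flash_attn.layers.rotary import RotaryEmbedding

rotary = RotaryEmbedding(64, device="cuda")
qkv = torch.randn(2, 1, 3, 8, 64, dtype=torch.float16, device="cuda")
seqlen_offset = torch.tensor([10, 37], dtype=torch.int32, device="cuda")
# With a tensor offset, pass max_seqlen so the cos/sin cache is built up to that length
qkv = rotary(qkv, seqlen_offset=seqlen_offset, max_seqlen=128)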

View File

@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
# We don't store these biases
state_dict.pop(f"transformer.layers.{l}.attention.bias")
state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
# We don't store these
state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
r"transformer.layers.\1.mixer.out_proj.",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).attention.rotary_emb.",
r"transformer.layers.\1.mixer.rotary_emb.",
key,
)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

View File

@@ -1,12 +1,10 @@
# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional
import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
from flash_attn.ops.triton.k_activations import (

View File

@@ -0,0 +1,182 @@
from typing import Union
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 2}),
# triton.Config({"BLOCK_M": 4}),
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}),
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
OUT, # Pointers to matrices
X,
COS,
SIN,
SEQLEN_OFFSETS, # this could be int or a pointer
# Matrix dimensions
seqlen,
nheads,
rotary_dim,
seqlen_ro,
CACHE_KEY_SEQLEN,
# strides
stride_out_batch,
stride_out_seqlen,
stride_out_nheads,
stride_out_headdim,
stride_x_batch,
stride_x_seqlen,
stride_x_nheads,
stride_x_headdim,
# Meta-parameters
BLOCK_K: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
BLOCK_M: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rk = tl.arange(0, BLOCK_K // 2)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
X = X + (
pid_batch * stride_x_batch
+ rm[:, None] * stride_x_seqlen
+ pid_head * stride_x_nheads
+ rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to(
tl.float32
)
x1 = tl.load(
X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if not CONJUGATE:
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
else:
o0 = x0 * cos + x1 * sin
o1 = -x0 * sin + x1 * cos
# write back result
OUT = OUT + (
pid_batch * stride_out_batch
+ rm[:, None] * stride_out_seqlen
+ pid_head * stride_out_nheads
+ rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
tl.store(
OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
o1,
mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
)
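For reference, the elementwise math the kernel program above performs, written out in plain PyTorch (illustrative): each (x0, x1) pair is rotated by its position's angle, and CONJUGATE=True applies the inverse rotation, which is exactly what the backward pass needs.

import torch

x0, x1 = torch.randn(8), torch.randn(8)
theta = torch.rand(8)
cos, sin = torch.cos(theta), torch.sin(theta)
o0, o1 = x0 * cos - x1 * sin, x0 * sin + x1 * cos    # CONJUGATE=False (forward rotation)
r0, r1 = o0 * cos + o1 * sin, -o0 * sin + o1 * cos   # CONJUGATE=True undoes the rotation
assert torch.allclose(r0, x0, atol=1e-6) and torch.allclose(r1, x1, atol=1e-6)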
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim)
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
Returns:
y: (batch, seqlen, nheads, headdim)
"""
batch, seqlen, nheads, headdim = x.shape
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
assert seqlen_offsets.shape == (batch,)
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
seqlen_offsets = seqlen_offsets.contiguous()
else:
assert seqlen_offsets + seqlen <= seqlen_ro
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
BLOCK_K = (
32
if rotary_dim <= 32
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
)
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
seqlen // 128, # key for triton cache (limit number of compilations)
output.stride(0), # strides
output.stride(1),
output.stride(2),
output.stride(3),
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
interleaved,
conjugate,
BLOCK_M,
)
return output
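A direct usage sketch (illustrative) of the raw wrapper with a partial rotary dimension: only the first rotary_dim channels are rotated, and with inplace=False the remaining channels are copied through unchanged.

import torch
from flash_attn.ops.triton.rotary import apply_rotary

x = torch.randn(2, 16, 4, 128, dtype=torch.float16, device="cuda")
angle = torch.rand(16, 32, device="cuda")               # rotary_dim = 64 < headdim = 128
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
y = apply_rotary(x, cos, sin, seqlen_offsets=0, interleaved=False, inplace=False)
assert torch.equal(y[..., 64:], x[..., 64:])            # tail beyond rotary_dim is copied through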

View File

@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
)
print(out_cg.sequences)
parallel_state.destroy_model_parallel()
if not rotary:
out_hf = model_hf.generate(
input_ids=input_ids,
@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
parallel_state.destroy_model_parallel()

View File

@@ -1,10 +1,12 @@
import math
import random
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 217
headdim = 128
device = "cuda"
torch.manual_seed(42)
x = torch.randn(
batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
x_pt = x.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
out = apply_rotary_emb_func(x, cos, sin, inplace)
out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
# Numerical error if we just do any arithmetic
atol = ((out + 0.3 - 0.3) - out).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb(
x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
out_pt = apply_rotary_emb_torch(
x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # If inplace=True, we might modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
if not inplace:
assert torch.equal(x, x_pt)
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 512
headdim = 128
device = "cuda"
torch.manual_seed(42)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
qkv_pt = qkv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_qkv_(
qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
q_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
k_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # Since inplace=True, we modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 781
headdim = 64
device = "cuda"
torch.manual_seed(42)
kv = torch.randn(
batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
kv_pt = kv.detach().clone().requires_grad_()
rotary_dim = int(rotary_fraction * headdim)
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if seqlen_offsets_type == 0:
seqlen_offsets = 0
elif seqlen_offsets_type is int:
seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
elif seqlen_offsets_type is torch.Tensor:
seqlen_offsets = torch.randint(
0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
)
out = apply_rotary_emb_kv_(
kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
if seqlen_offsets_type is torch.Tensor:
arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
k_pt = apply_rotary_emb_torch(
kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # Since inplace=True, we modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)