[Rotary] Implement rotary in Triton
parent 08e9847176 · commit 942fcbf046
@@ -1,11 +1,11 @@
# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import rotary_emb
import torch
from einops import rearrange, repeat
from flash_attn.ops.triton.rotary import apply_rotary


def rotate_half(x, interleaved=False):
@@ -20,12 +20,12 @@ def rotate_half(x, interleaved=False):
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "s d -> s 1 (2 d)")
    sin = repeat(sin, "s d -> s 1 (2 d)")
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
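
For reference, a minimal CPU sketch (not part of this commit) of what the pure-PyTorch path above computes, assuming flash_attn is importable; only the first rotary_dim features of x are rotated, the rest pass through unchanged:

import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb_torch

batch, seqlen, nheads, headdim, rotary_dim = 2, 5, 3, 16, 8
x = torch.randn(batch, seqlen, nheads, headdim)
angle = torch.rand(seqlen, rotary_dim // 2) * 2 * math.pi
cos, sin = torch.cos(angle), torch.sin(angle)

out_neox = apply_rotary_emb_torch(x, cos, sin, interleaved=False)  # rotate 1st half against 2nd half
out_gptj = apply_rotary_emb_torch(x, cos, sin, interleaved=True)   # rotate even/odd pairs
# Features beyond rotary_dim are passed through unchanged.
assert torch.equal(out_neox[..., rotary_dim:], x[..., rotary_dim:])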
@@ -34,229 +34,242 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):

class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (
                out_ro.chunk(2, dim=-1)
                if not interleaved
                else (out_ro[..., ::2], out_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            x1,
            x2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            o1,
            o2,
            False,
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        out = apply_rotary(
            x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
        )
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (
            do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
        )
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (
                dx_ro.chunk(2, dim=-1)
                if not ctx.interleaved
                else (dx_ro[..., ::2], dx_ro[..., 1::2])
            )
            rotary_emb.apply_rotary(
                do1,
                do2,
                rearrange(cos[:seqlen], "s d -> s 1 d"),
                rearrange(sin[:seqlen], "s d -> s 1 d"),
                dx1,
                dx2,
                True,
            cos, sin = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None
        return dx, None, None, None, None, None

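Why the new backward simply calls apply_rotary with conjugate=True: rotating each (x1, x2) pair by theta is an orthogonal map, so its vector-Jacobian product is the rotation by -theta. A small sketch (not from the diff) that checks this against autograd using the reference function:

import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb_torch

seqlen, headdim = 4, 8
x = torch.randn(1, seqlen, 1, headdim, dtype=torch.float64, requires_grad=True)
angle = torch.rand(seqlen, headdim // 2, dtype=torch.float64) * 2 * math.pi
cos, sin = torch.cos(angle), torch.sin(angle)

out = apply_rotary_emb_torch(x, cos, sin)
g = torch.randn_like(out)
out.backward(g)
# Rotating the upstream gradient by -angle (the "conjugate" rotation) matches autograd.
dx_conjugate = apply_rotary_emb_torch(g, cos, -sin)
assert torch.allclose(x.grad, dx_conjugate)
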
apply_rotary_emb_func = ApplyRotaryEmb.apply
def apply_rotary_emb(
    x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        out: (batch_size, seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb

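A usage sketch for the wrapper above (not part of the diff); it needs a CUDA device and a flash-attn build that ships the Triton kernel. Shapes follow the docstring, and the per-sequence offset form is the one used during KV-cache decoding:

import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb

batch, seqlen, nheads, headdim = 2, 128, 8, 64
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
angle = torch.rand(2 * seqlen, headdim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()

out = apply_rotary_emb(x, cos, sin)  # every sequence starts at position 0
offsets = torch.randint(0, seqlen, (batch,), dtype=torch.int32, device="cuda")
out_shifted = apply_rotary_emb(x, cos, sin, seqlen_offsets=offsets)  # per-sequence position shift
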
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
        """
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q_ro = qkv[:, :, 0, :, :rotary_dim]
        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
        rotary_emb.apply_rotary(
            q1,
            q2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            q1,
            q2,
            False,
        )
        k_ro = qkv[:, :, 1, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(
            k1,
            k2,
            rearrange(cos_k[:seqlen], "s d -> s 1 d"),
            rearrange(sin_k[:seqlen], "s d -> s 1 d"),
            k1,
            k2,
            False,
        )
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            q, k = qkv[:, :, 0], qkv[:, :, 1]
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        _, seqlen, _, _, headdim = dqkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
        dq1, dq2 = (
            dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dq1,
            dq2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dq1,
            dq2,
            True,
        )
        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
        dk1, dk2 = (
            dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dk1,
            dk2,
            rearrange(cos_k[:seqlen], "s d -> s 1 d"),
            rearrange(sin_k[:seqlen], "s d -> s 1 d"),
            dk1,
            dk2,
            True,
        )
        return dqkv, None, None, None, None, None
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            apply_rotary(
                dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        return dqkv, None, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
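
A small sketch (not from the diff) of the single-kernel trick in the forward above: when qkv is contiguous, the slice holding Q and K can be reshaped to combine the (2, nheads) dimensions into one tensor of 2 * nheads heads without copying, so a single in-place apply_rotary launch updates both. CPU-only check that the reshape really is a view:

import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 4, 3, 8
qkv = torch.randn(batch, seqlen, 3, nheads, headdim)
qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
assert qk.data_ptr() == qkv.data_ptr()  # a view over the same storage, not a copy
qk.zero_()  # stands in for the in-place rotary update
assert torch.all(qkv[:, :, :2] == 0) and not torch.all(qkv[:, :, 2] == 0)
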
def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    """
    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)

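Usage sketch for the in-place QKV wrapper above (not part of the diff; needs CUDA and the Triton kernel). Q and K are rotated in place while V is left untouched; the KV-only variant defined next behaves the same way for K:

import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_

batch, seqlen, nheads, headdim = 2, 64, 8, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda")
v_before = qkv[:, :, 2].clone()
angle = torch.rand(2 * seqlen, headdim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()

qkv = apply_rotary_emb_qkv_(qkv, cos, sin)  # rotates Q and K in place
assert torch.equal(qkv[:, :, 2], v_before)  # V is untouched
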
class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False):
        """
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of k.
        """
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        k_ro = kv[:, :, 0, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(
            k1,
            k2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            k1,
            k2,
            False,
        )  # conj=False since this is the forward pass
        ctx.save_for_backward(cos, sin)
        k = kv[:, :, 0]
        apply_rotary(
            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, _, headdim = dkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dk_ro = dkv[:, :, 0, :, :rotary_dim]
        dk1, dk2 = (
            dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        rotary_emb.apply_rotary(
            dk1,
            dk2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dk1,
            dk2,
            True,
        )  # conj=True since this is the backward pass
        return dkv, None, None, None
        return dkv, None, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply


def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of K.
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
@@ -372,57 +385,70 @@ class RotaryEmbedding(torch.nn.Module):
            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch.
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        elif max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    None,
                    None,
                    self.interleaved,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self._cos_k_cached[seqlen_offset:],
                    self._sin_k_cached[seqlen_offset:],
                    self.interleaved,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self.interleaved,
                True,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self.interleaved,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached[seqlen_offset:],
                    self._sin_k_cached[seqlen_offset:],
                    self.interleaved,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv

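A decode-time sketch (not from the diff) of the new seqlen_offset / max_seqlen semantics: during incremental generation qkv holds only the newest token per sequence, and each sequence may sit at a different position in its KV cache. The exact RotaryEmbedding constructor arguments are assumed here (rotary dimension first), so treat this as illustrative:

import torch
from flash_attn.layers.rotary import RotaryEmbedding

batch, nheads, headdim, max_seqlen = 4, 8, 64, 2048
rotary = RotaryEmbedding(headdim, device="cuda")  # constructor args assumed
qkv_step = torch.randn(batch, 1, 3, nheads, headdim, dtype=torch.float16, device="cuda")
positions = torch.tensor([5, 17, 0, 1023], dtype=torch.int32, device="cuda")
# With a tensor of offsets, pass max_seqlen so the cos/sin cache is built up front.
qkv_step = rotary(qkv_step, seqlen_offset=positions, max_seqlen=max_seqlen)
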
@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
    # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).attention.rotary_emb.",
            r"transformer.layers.\1.mixer.rotary_emb.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

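A toy sketch (not from the diff) of the key handling above: buffers that aren't stored, such as rotary_emb.inv_freq, are popped, and the remaining attention keys are renamed via re.sub. The source pattern for the out_proj rename sits outside this hunk, so the one used here is a plausible stand-in:

import re
from collections import OrderedDict

state_dict = OrderedDict({
    "transformer.layers.0.attention.dense.weight": "W",
    "transformer.layers.0.attention.rotary_emb.inv_freq": "buf",
})
state_dict.pop("transformer.layers.0.attention.rotary_emb.inv_freq", None)

def key_mapping_attn(key):
    # Pattern assumed for illustration; only the replacement string appears in the hunk.
    return re.sub(
        r"^transformer.layers.(\d+).attention.dense.",
        r"transformer.layers.\1.mixer.out_proj.",
        key,
    )

state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
assert list(state_dict) == ["transformer.layers.0.mixer.out_proj.weight"]
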
@@ -1,12 +1,10 @@
# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
flash_attn/ops/triton/rotary.py (new file, +182 lines)
@@ -0,0 +1,182 @@
from typing import Union

import torch

import triton
import triton.language as tl


# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 2}),
#         triton.Config({"BLOCK_M": 4}),
#         triton.Config({"BLOCK_M": 8}),
#         triton.Config({"BLOCK_M": 16}),
#     ],
#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    nheads,
    rotary_dim,
    seqlen_ro,
    CACHE_KEY_SEQLEN,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = tl.arange(0, BLOCK_K // 2)
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)

    X = X + (
        pid_batch * stride_x_batch
        + rm[:, None] * stride_x_seqlen
        + pid_head * stride_x_nheads
        + rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
    )
    COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])

    cos = tl.load(
        COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
    ).to(tl.float32)
    sin = tl.load(
        SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to(
        tl.float32
    )
    x1 = tl.load(
        X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
        other=0.0,
    ).to(tl.float32)
    if not CONJUGATE:
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
    else:
        o0 = x0 * cos + x1 * sin
        o1 = -x0 * sin + x1 * cos

    # write back result
    OUT = OUT + (
        pid_batch * stride_out_batch
        + rm[:, None] * stride_out_seqlen
        + pid_head * stride_out_nheads
        + rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
    )
    tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
    tl.store(
        OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
        o1,
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
    )

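How the kernel above pairs elements along the head dimension: in the non-interleaved (GPT-NeoX) layout x0 lives at offset rk and x1 at rk + rotary_dim // 2, while in the interleaved (GPT-J) layout the pair is (2 * rk, 2 * rk + 1); this is what the stride_x_headdim * (2 if INTERLEAVED else 1) and + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half) terms encode. A plain PyTorch emulation of that addressing for one head row (not part of the file):

import torch

rotary_dim = 8
x = torch.arange(rotary_dim)          # one row of the head dimension
rk = torch.arange(rotary_dim // 2)

x0_neox, x1_neox = x[rk], x[rk + rotary_dim // 2]   # pairs (0,4), (1,5), (2,6), (3,7)
x0_gptj, x1_gptj = x[2 * rk], x[2 * rk + 1]         # pairs (0,1), (2,3), (4,5), (6,7)
print(x0_neox.tolist(), x1_neox.tolist())
print(x0_gptj.tolist(), x1_gptj.tolist())
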
def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim)
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
    Returns:
        y: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"

    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

    rotary_kernel[grid](
        output,  # data ptrs
        x,
        cos,
        sin,
        seqlen_offsets,
        seqlen,  # shapes
        nheads,
        rotary_dim,
        seqlen_ro,
        seqlen // 128,  # key for triton cache (limit number of compilations)
        output.stride(0),  # strides
        output.stride(1),
        output.stride(2),
        output.stride(3),
        x.stride(0),
        x.stride(1),
        x.stride(2),
        x.stride(3),
        BLOCK_K,
        isinstance(seqlen_offsets, torch.Tensor),
        interleaved,
        conjugate,
        BLOCK_M,
    )
    return output
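
For a sense of how the launch above is sized (sketch, not part of the file): BLOCK_K is rotary_dim rounded up to 32/64/128/256, BLOCK_M is the number of sequence rows each program handles, and the grid puts one program per (sequence tile, batch, head):

import math

batch, seqlen, nheads, rotary_dim = 4, 300, 16, 128
interleaved = False
BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
grid = (math.ceil(seqlen / BLOCK_M), batch, nheads)  # mirrors triton.cdiv(seqlen, BLOCK_M)
print(BLOCK_K, BLOCK_M, grid)  # 128 4 (75, 4, 16)
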
@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
        )
        print(out_cg.sequences)

    parallel_state.destroy_model_parallel()

    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    ).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()

    parallel_state.destroy_model_parallel()

@@ -1,10 +1,12 @@
import math
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)

@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    out = apply_rotary_emb_func(x, cos, sin, inplace)
    out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
    # Numerical error if we just do any arithmetic
    atol = ((out + 0.3 - 0.3) - out).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")

    if not inplace:
        assert torch.equal(x, x_pt)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    qkv_pt = qkv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_qkv_(
        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    q_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")

    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_kv_(
        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")

    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)