[Rotary] Implement varlen rotary
parent 861c82577d
commit b28ec236df
@@ -42,27 +42,37 @@ class ApplyRotaryEmb(torch.autograd.Function):
         interleaved=False,
         inplace=False,
         seqlen_offsets: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
     ):
         out = apply_rotary(
-            x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
+            x,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            interleaved=interleaved,
+            inplace=inplace,
         )
         if isinstance(seqlen_offsets, int):
-            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
+            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
             ctx.seqlen_offsets = seqlen_offsets
         else:
-            ctx.save_for_backward(cos, sin, seqlen_offsets)
+            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
             ctx.seqlen_offsets = None
         ctx.interleaved = interleaved
         ctx.inplace = inplace
+        ctx.max_seqlen = max_seqlen
         return out if not inplace else x

     @staticmethod
     def backward(ctx, do):
         seqlen_offsets = ctx.seqlen_offsets
         if seqlen_offsets is None:
-            cos, sin, seqlen_offsets = ctx.saved_tensors
+            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
         else:
-            cos, sin = ctx.saved_tensors
+            cos, sin, cu_seqlens = ctx.saved_tensors
         # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
         # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
         if not ctx.interleaved and not ctx.inplace:
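The backward pass (continued in the next hunk) restores the same cu_seqlens segmentation and re-runs the rotary kernel with conjugate=True. As a quick, hedged sanity check of why conjugation (negating the sine) is the right backward for a rotation, here is plain rotation-matrix algebra, not code from this commit:

import math
import torch

# For a 2x2 rotation R(theta), the transpose (what the VJP multiplies by) equals R(-theta),
# i.e. the same rotation with the sine negated -- which is what conjugate=True implements.
theta = 0.3
R = torch.tensor([[math.cos(theta), -math.sin(theta)],
                  [math.sin(theta),  math.cos(theta)]])
R_minus = torch.tensor([[math.cos(-theta), -math.sin(-theta)],
                        [math.sin(-theta),  math.cos(-theta)]])
assert torch.allclose(R.transpose(0, 1), R_minus)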
@@ -72,31 +82,46 @@ class ApplyRotaryEmb(torch.autograd.Function):
             cos,
             sin,
             seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=ctx.max_seqlen,
             interleaved=ctx.interleaved,
             inplace=ctx.inplace,
             conjugate=True,
         )
-        return dx, None, None, None, None, None
+        return dx, None, None, None, None, None, None, None


 def apply_rotary_emb(
-    x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
+    x,
+    cos,
+    sin,
+    interleaved=False,
+    inplace=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
 ):
     """
     Arguments:
-        x: (batch_size, seqlen, nheads, headdim)
+        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
         cos, sin: (seqlen_rotary, rotary_dim / 2)
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
             of 1st half and 2nd half (GPT-NeoX style).
         inplace: if True, apply rotary embedding in-place.
         seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
             Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
     Return:
-        out: (batch_size, seqlen, nheads, headdim)
+        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
     rotary_dim must be <= headdim
     Apply rotary embedding to the first rotary_dim of x.
     """
-    return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)
+    return ApplyRotaryEmb.apply(
+        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
+    )


 # For backward compatibility
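A minimal usage sketch of the new varlen path (illustrative only; assumes flash-attn with this change and a CUDA device, and the lengths below are invented): sequences are packed into one (total_seqlen, nheads, headdim) tensor and segmented by cu_seqlens.

import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb

nheads, headdim, rotary_dim = 8, 64, 64
lengths = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")  # per-sequence lengths
cu_seqlens = torch.nn.functional.pad(torch.cumsum(lengths, 0, dtype=torch.int32), (1, 0))  # (batch + 1,)
max_seqlen = int(lengths.max())
total_seqlen = int(cu_seqlens[-1])

# Packed ("unpadded") tokens from all sequences, concatenated along dim 0.
x = torch.randn(total_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
angle = torch.rand(max_seqlen, rotary_dim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()

out = apply_rotary_emb(x, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
# out: (total_seqlen, nheads, headdim); each sequence is rotated starting from position 0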
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Optional, Union

 import torch

@@ -21,6 +21,7 @@ def rotary_kernel(
     X,
     COS,
     SIN,
+    CU_SEQLENS,
     SEQLEN_OFFSETS,  # this could be int or a pointer
     # Matrix dimensions
     seqlen,
@@ -40,6 +41,7 @@ def rotary_kernel(
     # Meta-parameters
     BLOCK_K: tl.constexpr,
     IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
     INTERLEAVED: tl.constexpr,
     CONJUGATE: tl.constexpr,
     BLOCK_M: tl.constexpr,
@@ -49,9 +51,17 @@ def rotary_kernel(
     pid_head = tl.program_id(axis=2)
     rotary_dim_half = rotary_dim // 2
-    X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
-    OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+
+    if not IS_VARLEN:
+        X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
+        OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+    else:
+        start_idx = tl.load(CU_SEQLENS + pid_batch)
+        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
+        X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
+        OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
+
     if pid_m * BLOCK_M >= seqlen:
         return
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     if not IS_SEQLEN_OFFSETS_TENSOR:
         rm_cs = rm + SEQLEN_OFFSETS
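A rough PyTorch analogue of the varlen branch above, just to make the pointer arithmetic concrete (a sketch, not the kernel): for sequence b, the kernel advances X/OUT by cu_seqlens[b] rows of the packed tensor and takes the per-sequence length from consecutive cu_seqlens entries.

def varlen_sequences(x_packed, cu_seqlens):
    # x_packed: (total_seqlen, nheads, headdim); cu_seqlens: (batch + 1,) int tensor
    for b in range(cu_seqlens.numel() - 1):
        start_idx = int(cu_seqlens[b])
        seqlen_b = int(cu_seqlens[b + 1]) - start_idx
        # This slice is what the kernel addresses via start_idx * stride_x_seqlen.
        yield x_packed[start_idx : start_idx + seqlen_b]  # (seqlen_b, nheads, headdim)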
@@ -134,20 +144,33 @@ def apply_rotary(
     cos: torch.Tensor,
     sin: torch.Tensor,
     seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
     interleaved=False,
     inplace=False,
     conjugate=False,
 ) -> torch.Tensor:
     """
     Arguments:
-        x: (batch, seqlen, nheads, headdim)
+        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim).
         cos: (seqlen_ro, rotary_dim / 2)
         sin: (seqlen_ro, rotary_dim / 2)
         seqlen_offsets: integer or integer tensor of size (batch,)
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
     Returns:
         y: (batch, seqlen, nheads, headdim)
     """
-    batch, seqlen, nheads, headdim = x.shape
+    is_varlen = cu_seqlens is not None
+    if not is_varlen:
+        batch, seqlen, nheads, headdim = x.shape
+    else:
+        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
+        total_seqlen, nheads, headdim = x.shape
+        batch_p_1 = cu_seqlens.shape[0]
+        batch = batch_p_1 - 1
+        seqlen = max_seqlen
     seqlen_ro, rotary_dim = cos.shape
     assert sin.shape == cos.shape
     rotary_dim *= 2
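For concreteness, a small sketch of how the varlen branch above recovers the batch size and seqlen (values invented; max_seqlen must still be supplied by the caller, the function does not derive it):

import torch

cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)  # three sequences of lengths 5, 3, 7
batch = cu_seqlens.shape[0] - 1                               # 3
total_seqlen = int(cu_seqlens[-1])                            # 15 == x.shape[0] for the packed input
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())    # 7; used where the fixed-length path uses seqlen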
@@ -187,22 +210,24 @@ def apply_rotary(
         x,
         cos,
         sin,
+        cu_seqlens,
         seqlen_offsets,
         seqlen,  # shapes
         nheads,
         rotary_dim,
         seqlen_ro,
         seqlen // 128,  # key for triton cache (limit number of compilations)
-        output.stride(0),  # strides
-        output.stride(1),
-        output.stride(2),
-        output.stride(3),
-        x.stride(0),
-        x.stride(1),
-        x.stride(2),
-        x.stride(3),
+        output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+        output.stride(-3),  # seqlen_stride or total_seqlen_stride
+        output.stride(-2),  # nheads_stride
+        output.stride(-1),  # headdim_stride
+        x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+        x.stride(-3),  # seqlen stride or total_seqlen_stride
+        x.stride(-2),  # nheads stride
+        x.stride(-1),  # headdim stride
         BLOCK_K,
         isinstance(seqlen_offsets, torch.Tensor),
+        is_varlen,
         interleaved,
         conjugate,
         BLOCK_M,
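The stride arguments now index the last three dimensions with negative offsets so the same tuple works for both layouts; in the varlen case the batch stride is passed as 0 because sequences are located through CU_SEQLENS instead. A small illustration (hypothetical shapes, not the actual launch):

import torch

def rotary_strides(t, is_varlen):
    # (batch_stride, seqlen_stride, nheads_stride, headdim_stride) in the order passed above
    return (
        t.stride(0) if not is_varlen else 0,  # batch stride is unused on the varlen path
        t.stride(-3),                         # seqlen (or total_seqlen) stride
        t.stride(-2),                         # nheads stride
        t.stride(-1),                         # headdim stride
    )

x_padded = torch.empty(2, 7, 4, 64)     # (batch, seqlen, nheads, headdim)
x_packed = torch.empty(12, 4, 64)       # (total_seqlen, nheads, headdim)
print(rotary_strides(x_padded, False))  # (1792, 256, 64, 1)
print(rotary_strides(x_packed, True))   # (0, 256, 64, 1)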
@@ -7,10 +7,41 @@ import torch.nn.functional as F
 from einops import rearrange
 from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
 from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
+from flash_attn.bert_padding import pad_input, unpad_input

 is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)


+def generate_cos_sin(seqlen, rotary_dim, device, dtype):
+    assert rotary_dim % 2 == 0
+    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
+    cos = torch.cos(angle).to(dtype=dtype)
+    sin = torch.sin(angle).to(dtype=dtype)
+    return cos, sin
+
+
+def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
+    if seqlen_offsets_type == 0:
+        return 0
+    elif seqlen_offsets_type is int:
+        return torch.randint(0, seqlen + 1, (1,)).item()
+    elif seqlen_offsets_type is torch.Tensor:
+        return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)
+
+
+def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
+    if isinstance(seqlen_offsets, torch.Tensor):
+        batch_size = seqlen_offsets.shape[0]
+        arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s")
+        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
+        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
+        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
+    else:
+        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
+        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    return cos_pt, sin_pt
+
+
 @pytest.mark.parametrize(
     "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
 )
@@ -30,35 +61,18 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t
     seqlen = 217
     headdim = 128
     device = "cuda"
+    rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
     x = torch.randn(
         batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
     )
     x_pt = x.detach().clone().requires_grad_()
-    rotary_dim = int(rotary_fraction * headdim)
-    assert rotary_dim % 2 == 0
-    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
-    cos = torch.cos(angle).to(dtype=dtype)
-    sin = torch.sin(angle).to(dtype=dtype)
-    if seqlen_offsets_type == 0:
-        seqlen_offsets = 0
-    elif seqlen_offsets_type is int:
-        seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
-    elif seqlen_offsets_type is torch.Tensor:
-        seqlen_offsets = torch.randint(
-            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
-        )
+    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
+    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
     out = apply_rotary_emb(
         x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
     )
-    if seqlen_offsets_type is torch.Tensor:
-        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
-        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
-        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-    else:
-        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
-        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
     out_pt = apply_rotary_emb_torch(
         x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
@@ -96,35 +110,18 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype
     seqlen = 512
     headdim = 128
     device = "cuda"
+    rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
     qkv = torch.randn(
         batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
     )
     qkv_pt = qkv.detach().clone().requires_grad_()
-    rotary_dim = int(rotary_fraction * headdim)
-    assert rotary_dim % 2 == 0
-    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
-    cos = torch.cos(angle).to(dtype=dtype)
-    sin = torch.sin(angle).to(dtype=dtype)
-    if seqlen_offsets_type == 0:
-        seqlen_offsets = 0
-    elif seqlen_offsets_type is int:
-        seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
-    elif seqlen_offsets_type is torch.Tensor:
-        seqlen_offsets = torch.randint(
-            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
-        )
+    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
+    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
     out = apply_rotary_emb_qkv_(
         qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
     )
-    if seqlen_offsets_type is torch.Tensor:
-        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
-        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
-        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-    else:
-        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
-        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
     q_pt = apply_rotary_emb_torch(
         qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
@@ -164,35 +161,16 @@ def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype)
     seqlen = 781
     headdim = 64
     device = "cuda"
+    rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
     kv = torch.randn(
         batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
     )
     kv_pt = kv.detach().clone().requires_grad_()
-    rotary_dim = int(rotary_fraction * headdim)
-    assert rotary_dim % 2 == 0
-    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
-    cos = torch.cos(angle).to(dtype=dtype)
-    sin = torch.sin(angle).to(dtype=dtype)
-    if seqlen_offsets_type == 0:
-        seqlen_offsets = 0
-    elif seqlen_offsets_type is int:
-        seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item()
-    elif seqlen_offsets_type is torch.Tensor:
-        seqlen_offsets = torch.randint(
-            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
-        )
-    out = apply_rotary_emb_kv_(
-        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
-    )
-    if seqlen_offsets_type is torch.Tensor:
-        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
-        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
-        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-    else:
-        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
-        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
+    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
+    out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved)
+    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
     k_pt = apply_rotary_emb_torch(
         kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
@@ -210,3 +188,67 @@ def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype)
     assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
     atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
     assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
+
+
+@pytest.mark.parametrize(
+    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
+)
+# @pytest.mark.parametrize("dtype", ([torch.float16]))
+@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
+# @pytest.mark.parametrize("seqlen_offsets_type", [0])
+@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
+# @pytest.mark.parametrize("rotary_fraction", [1.0])
+@pytest.mark.parametrize("interleaved", [False, True])
+# @pytest.mark.parametrize("interleaved", [True])
+@pytest.mark.parametrize("inplace", [False, True])
+# @pytest.mark.parametrize("inplace", [False])
+def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
+    rtol = 1e-3
+    batch_size = 32
+    nheads = 4
+    seqlen = 217
+    headdim = 128
+    device = "cuda"
+    rotary_dim = int(rotary_fraction * headdim)
+    torch.manual_seed(42)
+    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
+    x_pt = x.detach().clone().requires_grad_()
+    lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
+    padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
+    x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
+    x_unpad_clone = x_unpad.clone()
+    x_unpad = x_unpad.requires_grad_()
+    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
+    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
+    out_unpad = apply_rotary_emb(
+        x_unpad,
+        cos,
+        sin,
+        seqlen_offsets=seqlen_offsets,
+        interleaved=interleaved,
+        inplace=inplace,
+        cu_seqlens=cu_seqlens,
+        max_seqlen=max_seqlen,
+    )
+    out = pad_input(out_unpad, indices, batch_size, seqlen)
+    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
+    out_pt = apply_rotary_emb_torch(
+        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+    ).to(dtype=dtype)
+    out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
+    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
+
+    g = torch.randn_like(out)
+    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
+    out.backward(g)
+    out_pt.backward(g_pt)
+    x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
+    print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}")
+
+    if not inplace:
+        assert torch.equal(x_unpad, x_unpad_clone)
+    # Numerical error if we just do any arithmetic
+    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
+    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
+    assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
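As a complement to the test above (which round-trips through unpad_input/pad_input), here is a hedged pure-PyTorch reference for the varlen path that loops over sequences directly on the packed tensor; apply_rotary_varlen_reference is a made-up helper name, and the sketch assumes seqlen_offsets=0.

import torch
from flash_attn.layers.rotary import apply_rotary_emb_torch

def apply_rotary_varlen_reference(x_unpad, cos, sin, cu_seqlens, interleaved=False):
    # x_unpad: (total_seqlen, nheads, headdim), sequences packed along dim 0.
    out = torch.empty_like(x_unpad)
    for b in range(cu_seqlens.numel() - 1):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        seq = x_unpad[start:end].unsqueeze(0)  # (1, seqlen_b, nheads, headdim)
        rotated = apply_rotary_emb_torch(
            seq.float(), cos[: end - start].float(), sin[: end - start].float(),
            interleaved=interleaved,
        )
        out[start:end] = rotated.squeeze(0).to(x_unpad.dtype)
    return out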