From b28ec236df49cb57c213a9a4f29beb727adbf65b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 3 Sep 2023 17:40:00 -0700 Subject: [PATCH] [Rotary] Implement varlen rotary --- flash_attn/layers/rotary.py | 45 +++++++-- flash_attn/ops/triton/rotary.py | 51 +++++++--- tests/test_rotary.py | 174 ++++++++++++++++++++------------ 3 files changed, 181 insertions(+), 89 deletions(-) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 5f68a84..e081770 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -42,27 +42,37 @@ class ApplyRotaryEmb(torch.autograd.Function): interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, ): out = apply_rotary( - x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, ) if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward ctx.seqlen_offsets = seqlen_offsets else: - ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) ctx.seqlen_offsets = None ctx.interleaved = interleaved ctx.inplace = inplace + ctx.max_seqlen = max_seqlen return out if not inplace else x @staticmethod def backward(ctx, do): seqlen_offsets = ctx.seqlen_offsets if seqlen_offsets is None: - cos, sin, seqlen_offsets = ctx.saved_tensors + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors else: - cos, sin = ctx.saved_tensors + cos, sin, cu_seqlens = ctx.saved_tensors # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. if not ctx.interleaved and not ctx.inplace: @@ -72,31 +82,46 @@ class ApplyRotaryEmb(torch.autograd.Function): cos, sin, seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, interleaved=ctx.interleaved, inplace=ctx.inplace, conjugate=True, ) - return dx, None, None, None, None, None + return dx, None, None, None, None, None, None, None def apply_rotary_emb( - x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0 + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, ): """ Arguments: - x: (batch_size, seqlen, nheads, headdim) + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) cos, sin: (seqlen_rotary, rotary_dim / 2) interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). inplace: if True, apply rotary embedding in-place. seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int Return: - out: (batch_size, seqlen, nheads, headdim) + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) rotary_dim must be <= headdim Apply rotary embedding to the first rotary_dim of x. 
""" - return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets) + return ApplyRotaryEmb.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) # For backward compatibility diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py index b526981..ba846a0 100644 --- a/flash_attn/ops/triton/rotary.py +++ b/flash_attn/ops/triton/rotary.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Optional, Union import torch @@ -21,6 +21,7 @@ def rotary_kernel( X, COS, SIN, + CU_SEQLENS, SEQLEN_OFFSETS, # this could be int or a pointer # Matrix dimensions seqlen, @@ -40,6 +41,7 @@ def rotary_kernel( # Meta-parameters BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr, @@ -49,9 +51,17 @@ def rotary_kernel( pid_head = tl.program_id(axis=2) rotary_dim_half = rotary_dim // 2 - X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + if pid_m * BLOCK_M >= seqlen: + return rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) if not IS_SEQLEN_OFFSETS_TENSOR: rm_cs = rm + SEQLEN_OFFSETS @@ -134,20 +144,33 @@ def apply_rotary( cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, interleaved=False, inplace=False, conjugate=False, ) -> torch.Tensor: """ Arguments: - x: (batch, seqlen, nheads, headdim) + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
cos: (seqlen_ro, rotary_dim / 2) sin: (seqlen_ro, rotary_dim / 2) seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int Returns: y: (batch, seqlen, nheads, headdim) """ - batch, seqlen, nheads, headdim = x.shape + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen seqlen_ro, rotary_dim = cos.shape assert sin.shape == cos.shape rotary_dim *= 2 @@ -187,22 +210,24 @@ def apply_rotary( x, cos, sin, + cu_seqlens, seqlen_offsets, seqlen, # shapes nheads, rotary_dim, seqlen_ro, seqlen // 128, # key for triton cache (limit number of compilations) - output.stride(0), # strides - output.stride(1), - output.stride(2), - output.stride(3), - x.stride(0), - x.stride(1), - x.stride(2), - x.stride(3), + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride BLOCK_K, isinstance(seqlen_offsets, torch.Tensor), + is_varlen, interleaved, conjugate, BLOCK_M, diff --git a/tests/test_rotary.py b/tests/test_rotary.py index f3213f5..574d052 100644 --- a/tests/test_rotary.py +++ b/tests/test_rotary.py @@ -7,10 +7,41 @@ import torch.nn.functional as F from einops import rearrange from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_ +from flash_attn.bert_padding import pad_input, unpad_input is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0) +def generate_cos_sin(seqlen, rotary_dim, device, dtype): + assert rotary_dim % 2 == 0 + angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=dtype) + sin = torch.sin(angle).to(dtype=dtype) + return cos, sin + + +def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device): + if seqlen_offsets_type == 0: + return 0 + elif seqlen_offsets_type is int: + return torch.randint(0, seqlen + 1, (1,)).item() + elif seqlen_offsets_type is torch.Tensor: + return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device) + + +def index_cos_sin(cos, sin, seqlen_offsets, seqlen): + if isinstance(seqlen_offsets, torch.Tensor): + batch_size = seqlen_offsets.shape[0] + arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s") + idx = rearrange(seqlen_offsets, "b -> b 1") + arange + cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size) + sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size) + else: + cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen] + sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen] + return cos_pt, sin_pt + + @pytest.mark.parametrize( "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]) ) @@ -30,35 +61,18 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t seqlen = 217 headdim = 128 device = "cuda" + rotary_dim = int(rotary_fraction * headdim) torch.manual_seed(42) x = torch.randn( 
batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True ) x_pt = x.detach().clone().requires_grad_() - rotary_dim = int(rotary_fraction * headdim) - assert rotary_dim % 2 == 0 - angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi - cos = torch.cos(angle).to(dtype=dtype) - sin = torch.sin(angle).to(dtype=dtype) - if seqlen_offsets_type == 0: - seqlen_offsets = 0 - elif seqlen_offsets_type is int: - seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item() - elif seqlen_offsets_type is torch.Tensor: - seqlen_offsets = torch.randint( - 0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device - ) + cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype) + seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device) out = apply_rotary_emb( x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace ) - if seqlen_offsets_type is torch.Tensor: - arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") - idx = rearrange(seqlen_offsets, "b -> b 1") + arange - cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size) - sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size) - else: - cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen] - sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen] + cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen) out_pt = apply_rotary_emb_torch( x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved ).to(dtype=dtype) @@ -96,35 +110,18 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype seqlen = 512 headdim = 128 device = "cuda" + rotary_dim = int(rotary_fraction * headdim) torch.manual_seed(42) qkv = torch.randn( batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True ) qkv_pt = qkv.detach().clone().requires_grad_() - rotary_dim = int(rotary_fraction * headdim) - assert rotary_dim % 2 == 0 - angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi - cos = torch.cos(angle).to(dtype=dtype) - sin = torch.sin(angle).to(dtype=dtype) - if seqlen_offsets_type == 0: - seqlen_offsets = 0 - elif seqlen_offsets_type is int: - seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item() - elif seqlen_offsets_type is torch.Tensor: - seqlen_offsets = torch.randint( - 0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device - ) + cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype) + seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device) out = apply_rotary_emb_qkv_( qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved ) - if seqlen_offsets_type is torch.Tensor: - arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") - idx = rearrange(seqlen_offsets, "b -> b 1") + arange - cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size) - sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size) - else: - cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen] - sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen] + cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen) q_pt = apply_rotary_emb_torch( qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved ).to(dtype=dtype) @@ -164,35 +161,16 @@ def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype) seqlen = 781 headdim = 64 device = "cuda" + rotary_dim = int(rotary_fraction * headdim) 
torch.manual_seed(42) kv = torch.randn( batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True ) kv_pt = kv.detach().clone().requires_grad_() - rotary_dim = int(rotary_fraction * headdim) - assert rotary_dim % 2 == 0 - angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi - cos = torch.cos(angle).to(dtype=dtype) - sin = torch.sin(angle).to(dtype=dtype) - if seqlen_offsets_type == 0: - seqlen_offsets = 0 - elif seqlen_offsets_type is int: - seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item() - elif seqlen_offsets_type is torch.Tensor: - seqlen_offsets = torch.randint( - 0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device - ) - out = apply_rotary_emb_kv_( - kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved - ) - if seqlen_offsets_type is torch.Tensor: - arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") - idx = rearrange(seqlen_offsets, "b -> b 1") + arange - cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size) - sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size) - else: - cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen] - sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen] + cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype) + seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device) + out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved) + cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen) k_pt = apply_rotary_emb_torch( kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved ).to(dtype=dtype) @@ -210,3 +188,67 @@ def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype) assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item() assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol) + + +@pytest.mark.parametrize( + "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]) +) +# @pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor]) +# @pytest.mark.parametrize("seqlen_offsets_type", [0]) +@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5]) +# @pytest.mark.parametrize("rotary_fraction", [1.0]) +@pytest.mark.parametrize("interleaved", [False, True]) +# @pytest.mark.parametrize("interleaved", [True]) +@pytest.mark.parametrize("inplace", [False, True]) +# @pytest.mark.parametrize("inplace", [False]) +def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype): + rtol = 1e-3 + batch_size = 32 + nheads = 4 + seqlen = 217 + headdim = 128 + device = "cuda" + rotary_dim = int(rotary_fraction * headdim) + torch.manual_seed(42) + x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device) + x_pt = x.detach().clone().requires_grad_() + lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device) + padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths + x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask) + x_unpad_clone = x_unpad.clone() + x_unpad = x_unpad.requires_grad_() + cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype) + seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device) + out_unpad = apply_rotary_emb( + x_unpad, + cos, + sin, + 
seqlen_offsets=seqlen_offsets, + interleaved=interleaved, + inplace=inplace, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + out = pad_input(out_unpad, indices, batch_size, seqlen) + cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen) + out_pt = apply_rotary_emb_torch( + x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved + ).to(dtype=dtype) + out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0) + print(f"Output max diff: {(out - out_pt).abs().max().item()}") + + g = torch.randn_like(out) + g_pt = g.clone() # If inplace=True, we might modify the gradient inplace + out.backward(g) + out_pt.backward(g_pt) + x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen) + print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}") + + if not inplace: + assert torch.equal(x_unpad, x_unpad_clone) + # Numerical error if we just do any arithmetic + atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item() + assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) + atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item() + assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
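
Below is a minimal usage sketch (not part of the patch) of the varlen path that apply_rotary_emb gains here: rotary applied directly to the packed (total_seqlen, nheads, headdim) layout via the new cu_seqlens/max_seqlen arguments, mirroring how the added test drives unpad_input/pad_input. It assumes flash-attn with this patch applied and a CUDA device; the batch/sequence sizes and the standard inverse-frequency cos/sin construction are illustrative, not taken from the patch.

# Usage sketch only -- assumes flash-attn with this patch applied; sizes are illustrative.
import torch
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.layers.rotary import apply_rotary_emb

batch_size, seqlen, nheads, headdim, rotary_dim = 2, 128, 4, 64, 64
device, dtype = "cuda", torch.float16

x = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
# Boolean mask marking the valid (non-padding) positions of each sequence.
lengths = torch.tensor([[100], [128]], device=device)
padding_mask = torch.arange(seqlen, device=device)[None, :] < lengths

# unpad_input packs the valid tokens into (total_seqlen, nheads, headdim) and
# returns cu_seqlens (cumulative lengths, shape (batch + 1,)) plus max_seqlen,
# matching the 4-value return used in the test above.
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)

# cos/sin tables of shape (seqlen_ro, rotary_dim / 2); here the usual
# inverse-frequency construction rather than the random angles used in the test.
t = torch.arange(seqlen, device=device, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, device=device).float() / rotary_dim))
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos().to(dtype), freqs.sin().to(dtype)

# Rotary on the packed layout: each sequence starts at position 0 and padding
# tokens are never touched.
out_unpad = apply_rotary_emb(
    x_unpad, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
)
# Re-pad if a (batch, seqlen, nheads, headdim) layout is needed downstream.
out = pad_input(out_unpad, indices, batch_size, seqlen)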