From 942fcbf0463962173a9b0133bacadfa4783a65dc Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 3 Sep 2023 02:44:59 -0700 Subject: [PATCH] [Rotary] Implement rotary in Triton --- flash_attn/layers/rotary.py | 448 ++++++++++--------- flash_attn/models/gpt_neox.py | 7 +- flash_attn/ops/triton/linear.py | 4 +- flash_attn/ops/triton/rotary.py | 182 ++++++++ tests/models/test_gpt_generation_parallel.py | 4 +- tests/test_rotary.py | 187 +++++++- 6 files changed, 601 insertions(+), 231 deletions(-) create mode 100644 flash_attn/ops/triton/rotary.py diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 4fb58d0..64ae7de 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -1,11 +1,11 @@ # Copyright (c) 2023, Tri Dao. import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Union -import rotary_emb import torch from einops import rearrange, repeat +from flash_attn.ops.triton.rotary import apply_rotary def rotate_half(x, interleaved=False): @@ -20,12 +20,12 @@ def rotate_half(x, interleaved=False): def apply_rotary_emb_torch(x, cos, sin, interleaved=False): """ x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) """ ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] - cos = repeat(cos, "s d -> s 1 (2 d)") - sin = repeat(sin, "s d -> s 1 (2 d)") + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") return torch.cat( [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], dim=-1, @@ -34,229 +34,242 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False): class ApplyRotaryEmb(torch.autograd.Function): @staticmethod - def forward(ctx, x, cos, sin, interleaved=False, inplace=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
- """ - batch, seqlen, nheads, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) - out = torch.empty_like(x) if not inplace else x - out_ro = out[..., :rotary_dim] - if inplace: - o1, o2 = x1, x2 - else: - o1, o2 = ( - out_ro.chunk(2, dim=-1) - if not interleaved - else (out_ro[..., ::2], out_ro[..., 1::2]) - ) - rotary_emb.apply_rotary( - x1, - x2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - o1, - o2, - False, + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + ): + out = apply_rotary( + x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace ) - if not inplace and rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.seqlen_offsets = None ctx.interleaved = interleaved ctx.inplace = inplace return out if not inplace else x @staticmethod def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - inplace = ctx.inplace - do_ro = do[..., :rotary_dim] - do1, do2 = ( - do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) - ) - dx = torch.empty_like(do) if not inplace else do - if inplace: - dx1, dx2 = do1, do2 + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, seqlen_offsets = ctx.saved_tensors else: - dx_ro = dx[..., :rotary_dim] - dx1, dx2 = ( - dx_ro.chunk(2, dim=-1) - if not ctx.interleaved - else (dx_ro[..., ::2], dx_ro[..., 1::2]) - ) - rotary_emb.apply_rotary( - do1, - do2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - dx1, - dx2, - True, + cos, sin = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, ) - if not inplace and rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None + return dx, None, None, None, None, None -apply_rotary_emb_func = ApplyRotaryEmb.apply +def apply_rotary_emb( + x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0 +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + out: (batch_size, seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets) + + +# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb class ApplyRotaryEmbQKV_(torch.autograd.Function): @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): - """ - qkv: (batch_size, seqlen, 3, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of q and k. - """ + def forward( + ctx, + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + ): batch, seqlen, three, nheads, headdim = qkv.shape assert three == 3 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - q_ro = qkv[:, :, 0, :, :rotary_dim] - q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2]) - rotary_emb.apply_rotary( - q1, - q2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - q1, - q2, - False, - ) - k_ro = qkv[:, :, 1, :, :rotary_dim] - k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) - rotary_emb.apply_rotary( - k1, - k2, - rearrange(cos_k[:seqlen], "s d -> s 1 d"), - rearrange(sin_k[:seqlen], "s d -> s 1 d"), - k1, - k2, - False, - ) - ctx.save_for_backward(cos, sin, cos_k, sin_k) + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + apply_rotary( + qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True + ) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + q, k = qkv[:, :, 0], qkv[:, :, 1] + apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) + apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) + ctx.seqlen_offsets = None ctx.interleaved = interleaved return qkv @staticmethod def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - _, seqlen, _, _, headdim = dqkv.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq_ro = dqkv[:, :, 0, :, :rotary_dim] - dq1, dq2 = ( - dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2]) - ) - rotary_emb.apply_rotary( - dq1, - dq2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - dq1, - dq2, - True, - ) - dk_ro = dqkv[:, :, 1, :, :rotary_dim] - dk1, dk2 = ( - dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2]) - ) - rotary_emb.apply_rotary( - dk1, - dk2, - rearrange(cos_k[:seqlen], "s d -> s 1 d"), - rearrange(sin_k[:seqlen], "s d -> s 1 d"), - 
dk1,
-            dk2,
-            True,
-        )
-        return dqkv, None, None, None, None, None
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cos_k, sin_k = ctx.saved_tensors
+        if cos_k is None and sin_k is None and dqkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
+            apply_rotary(
+                dqk,
+                cos,
+                sin,
+                seqlen_offsets=seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+            )
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
+            apply_rotary(
+                dq,
+                cos,
+                sin,
+                seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+            )
+            apply_rotary(
+                dk,
+                cos_k,
+                sin_k,
+                seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+            )
+        return dqkv, None, None, None, None, None, None
 
 
-apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
+def apply_rotary_emb_qkv_(
+    qkv,
+    cos,
+    sin,
+    cos_k=None,
+    sin_k=None,
+    interleaved=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
+        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+    Return:
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
+    """
+    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
 
 
 class ApplyRotaryEmbKV_(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, kv, cos, sin, interleaved=False):
-        """
-        kv: (batch_size, seqlen, 2, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
-            1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding *inplace* to the first rotary_dim of k.
- """ + def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): batch, seqlen, two, nheads, headdim = kv.shape assert two == 2 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - k_ro = kv[:, :, 0, :, :rotary_dim] - k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) - rotary_emb.apply_rotary( - k1, - k2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - k1, - k2, - False, - ) # conj=False since this is the forward pass - ctx.save_for_backward(cos, sin) + k = kv[:, :, 0] + apply_rotary( + k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.seqlen_offsets = None ctx.interleaved = interleaved return kv @staticmethod def backward(ctx, dkv): - cos, sin = ctx.saved_tensors - _, seqlen, _, _, headdim = dkv.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dk_ro = dkv[:, :, 0, :, :rotary_dim] - dk1, dk2 = ( - dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2]) + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, seqlen_offsets = ctx.saved_tensors + else: + cos, sin = ctx.saved_tensors + apply_rotary( + dkv[:, :, 0], + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, ) - rotary_emb.apply_rotary( - dk1, - dk2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - dk1, - dk2, - True, - ) # conj=True since this is the backward pass - return dkv, None, None, None + return dkv, None, None, None, None apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply +def apply_rotary_emb_kv_( + kv, + cos, + sin, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + kv: (batch_size, seqlen, 2, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + kv: (batch_size, seqlen, 2, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of K. + """ + return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) + + class RotaryEmbedding(torch.nn.Module): """ The rotary position embeddings from RoFormer_ (Su et. al). @@ -372,57 +385,70 @@ class RotaryEmbedding(torch.nn.Module): self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) def forward( - self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0 + self, + qkv: torch.Tensor, + kv: Optional[torch.Tensor] = None, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim) kv: (batch, seqlen, 2, nheads, headdim) - seqlen_offset: can be used in generation where the qkv being passed in is only the last - token in the batch. + seqlen_offset: (batch_size,) or int. 
Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. """ seqlen = qkv.shape[1] - self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + if isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + elif max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) if kv is None: if self.scale is None: return apply_rotary_emb_qkv_( qkv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - None, - None, - self.interleaved, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, ) else: return apply_rotary_emb_qkv_( qkv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - self._cos_k_cached[seqlen_offset:], - self._sin_k_cached[seqlen_offset:], - self.interleaved, + self._cos_cached, + self._sin_cached, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, ) else: q = qkv q = apply_rotary_emb_func( q, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - self.interleaved, - True, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offset, ) if self.scale is None: kv = apply_rotary_emb_kv_( kv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - self.interleaved, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, ) else: kv = apply_rotary_emb_kv_( kv, - self._cos_k_cached[seqlen_offset:], - self._sin_k_cached[seqlen_offset:], - self.interleaved, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, ) return q, kv diff --git a/flash_attn/models/gpt_neox.py b/flash_attn/models/gpt_neox.py index 5476380..3a8fa07 100644 --- a/flash_attn/models/gpt_neox.py +++ b/flash_attn/models/gpt_neox.py @@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config): # We don't store these biases state_dict.pop(f"transformer.layers.{l}.attention.bias") state_dict.pop(f"transformer.layers.{l}.attention.masked_bias") + # We don't store these + state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None) # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim) # while we store Wqkv as ((3 nheads headdim), hidden_dim) headdim = config.hidden_size // config.num_attention_heads @@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config): r"transformer.layers.\1.mixer.out_proj.", key, ) - key = re.sub( - r"^transformer.layers.(\d+).attention.rotary_emb.", - r"transformer.layers.\1.mixer.rotary_emb.", - key, - ) return key state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) diff --git a/flash_attn/ops/triton/linear.py b/flash_attn/ops/triton/linear.py index 0eb6f02..a8966db 100644 --- a/flash_attn/ops/triton/linear.py +++ b/flash_attn/ops/triton/linear.py @@ -1,12 +1,10 @@ -# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py +# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py # and 
https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py from typing import Optional import torch import triton import triton.language as tl -from torch.autograd.function import FunctionCtx -from torch.cuda.amp import custom_fwd from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time from flash_attn.ops.triton.k_activations import ( diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py new file mode 100644 index 0000000..ee2e93b --- /dev/null +++ b/flash_attn/ops/triton/rotary.py @@ -0,0 +1,182 @@ +from typing import Union + +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 2}), +# triton.Config({"BLOCK_M": 4}), +# triton.Config({"BLOCK_M": 8}), +# triton.Config({"BLOCK_M": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"] +# ) +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk = tl.arange(0, BLOCK_K // 2) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + + X = X + ( + pid_batch * stride_x_batch + + rm[:, None] * stride_x_seqlen + + pid_head * stride_x_nheads + + rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1) + ) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :]) + + cos = tl.load( + COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0 + ).to(tl.float32) + sin = tl.load( + SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0).to( + tl.float32 + ) + x1 = tl.load( + X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half), + mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if not CONJUGATE: + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + else: + o0 = x0 * cos + x1 * sin + o1 = -x0 * sin + x1 * cos + + # write back result + OUT = OUT + ( + pid_batch * stride_out_batch + + rm[:, None] * stride_out_seqlen + + pid_head * stride_out_nheads + + rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1) + ) + tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half)) + tl.store( + OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half), + o1, + mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), + ) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) + cos: (seqlen_ro, rotary_dim / 2) + sin: 
(seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + Returns: + y: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 + if rotary_dim <= 32 + else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + ) + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # key for triton cache (limit number of compilations) + output.stride(0), # strides + output.stride(1), + output.stride(2), + output.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + interleaved, + conjugate, + BLOCK_M, + ) + return output diff --git a/tests/models/test_gpt_generation_parallel.py b/tests/models/test_gpt_generation_parallel.py index 1a7e735..3fbc5da 100644 --- a/tests/models/test_gpt_generation_parallel.py +++ b/tests/models/test_gpt_generation_parallel.py @@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size): ) print(out_cg.sequences) + parallel_state.destroy_model_parallel() + if not rotary: out_hf = model_hf.generate( input_ids=input_ids, @@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size): ).abs().max().item() < 3 * ( torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1) ).abs().max().item() - - parallel_state.destroy_model_parallel() diff --git a/tests/test_rotary.py b/tests/test_rotary.py index ab9c3bb..f312d1c 100644 --- a/tests/test_rotary.py +++ b/tests/test_rotary.py @@ -1,10 +1,12 @@ import math +import random import pytest import torch import torch.nn.functional as F from einops import rearrange -from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch +from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch +from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0) @@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0) "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]) ) # @pytest.mark.parametrize('dtype', ([torch.float16])) +@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor]) +# @pytest.mark.parametrize("seqlen_offsets_type", [0]) 
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5]) -# @pytest.mark.parametrize('rotary_fraction', [0.5]) +# @pytest.mark.parametrize('rotary_fraction', [1.0]) +@pytest.mark.parametrize("interleaved", [False, True]) +# @pytest.mark.parametrize('interleaved', [False]) @pytest.mark.parametrize("inplace", [False, True]) # @pytest.mark.parametrize('inplace', [False]) -def test_rotary_single_tensor(inplace, rotary_fraction, dtype): +def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype): rtol = 1e-3 batch_size = 32 nheads = 4 seqlen = 217 headdim = 128 + device = "cuda" + torch.manual_seed(42) x = torch.randn( - batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True + batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True ) x_pt = x.detach().clone().requires_grad_() rotary_dim = int(rotary_fraction * headdim) assert rotary_dim % 2 == 0 - angle = torch.randn(seqlen, rotary_dim // 2, device="cuda") + angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi cos = torch.cos(angle).to(dtype=dtype) sin = torch.sin(angle).to(dtype=dtype) - out = apply_rotary_emb_func(x, cos, sin, inplace) - out_pt = apply_rotary_emb_torch(x_pt, cos, sin) - # Numerical error if we just do any arithmetic - atol = ((out + 0.3 - 0.3) - out).abs().max().item() - assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) + if seqlen_offsets_type == 0: + seqlen_offsets = 0 + elif seqlen_offsets_type is int: + seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item() + elif seqlen_offsets_type is torch.Tensor: + seqlen_offsets = torch.randint( + 0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device + ) + out = apply_rotary_emb( + x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace + ) + if seqlen_offsets_type is torch.Tensor: + arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") + idx = rearrange(seqlen_offsets, "b -> b 1") + arange + cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size) + sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size) + else: + cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen] + sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen] + out_pt = apply_rotary_emb_torch( + x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved + ).to(dtype=dtype) + print(f"Output max diff: {(out - out_pt).abs().max().item()}") + g = torch.randn_like(out) g_pt = g.clone() # If inplace=True, we might modify the gradient inplace out.backward(g) out_pt.backward(g_pt) + print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}") + + if not inplace: + assert torch.equal(x, x_pt) + # Numerical error if we just do any arithmetic + atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item() + assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item() assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol) + + +@pytest.mark.parametrize( + "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]) +) +# @pytest.mark.parametrize('dtype', ([torch.float16])) +@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor]) +# @pytest.mark.parametrize("seqlen_offsets_type", [0]) +@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5]) +# @pytest.mark.parametrize('rotary_fraction', [1.0]) +@pytest.mark.parametrize("interleaved", [False, True]) +# @pytest.mark.parametrize('interleaved', [False]) 
+def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype): + rtol = 1e-3 + batch_size = 32 + nheads = 4 + seqlen = 512 + headdim = 128 + device = "cuda" + torch.manual_seed(42) + qkv = torch.randn( + batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True + ) + qkv_pt = qkv.detach().clone().requires_grad_() + rotary_dim = int(rotary_fraction * headdim) + assert rotary_dim % 2 == 0 + angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=dtype) + sin = torch.sin(angle).to(dtype=dtype) + if seqlen_offsets_type == 0: + seqlen_offsets = 0 + elif seqlen_offsets_type is int: + seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item() + elif seqlen_offsets_type is torch.Tensor: + seqlen_offsets = torch.randint( + 0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device + ) + out = apply_rotary_emb_qkv_( + qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved + ) + if seqlen_offsets_type is torch.Tensor: + arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") + idx = rearrange(seqlen_offsets, "b -> b 1") + arange + cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size) + sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size) + else: + cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen] + sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen] + q_pt = apply_rotary_emb_torch( + qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved + ).to(dtype=dtype) + k_pt = apply_rotary_emb_torch( + qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved + ).to(dtype=dtype) + out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2) + print(f"Output max diff: {(out - out_pt).abs().max().item()}") + + g = torch.randn_like(out) + g_pt = g.clone() # Since inplace=True, we modify the gradient inplace + out.backward(g) + out_pt.backward(g_pt) + print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}") + + # Numerical error if we just do any arithmetic + atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item() + assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) + atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item() + assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol) + + +@pytest.mark.parametrize( + "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]) +) +# @pytest.mark.parametrize('dtype', ([torch.float16])) +@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor]) +# @pytest.mark.parametrize("seqlen_offsets_type", [0]) +@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5]) +# @pytest.mark.parametrize('rotary_fraction', [1.0]) +@pytest.mark.parametrize("interleaved", [False, True]) +# @pytest.mark.parametrize('interleaved', [False]) +def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype): + rtol = 1e-3 + batch_size = 32 + nheads = 4 + seqlen = 781 + headdim = 64 + device = "cuda" + torch.manual_seed(42) + kv = torch.randn( + batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True + ) + kv_pt = kv.detach().clone().requires_grad_() + rotary_dim = int(rotary_fraction * headdim) + assert rotary_dim % 2 == 0 + angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=dtype) + sin = torch.sin(angle).to(dtype=dtype) + if seqlen_offsets_type == 0: + seqlen_offsets = 0 + elif 
seqlen_offsets_type is int: + seqlen_offsets = torch.randint(0, seqlen + 1, (1, )).item() + elif seqlen_offsets_type is torch.Tensor: + seqlen_offsets = torch.randint( + 0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device + ) + out = apply_rotary_emb_kv_( + kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved + ) + if seqlen_offsets_type is torch.Tensor: + arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") + idx = rearrange(seqlen_offsets, "b -> b 1") + arange + cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size) + sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size) + else: + cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen] + sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen] + k_pt = apply_rotary_emb_torch( + kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved + ).to(dtype=dtype) + out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2) + print(f"Output max diff: {(out - out_pt).abs().max().item()}") + + g = torch.randn_like(out) + g_pt = g.clone() # Since inplace=True, we modify the gradient inplace + out.backward(g) + out_pt.backward(g_pt) + print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}") + + # Numerical error if we just do any arithmetic + atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item() + assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) + atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item() + assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
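
Usage sketch (reviewer note, not part of the patch): a minimal example of calling the Triton-backed rotary API introduced in flash_attn/layers/rotary.py with per-sequence seqlen_offsets, checked against the pure-PyTorch reference apply_rotary_emb_torch. The shapes, dtypes, and CUDA device below are illustrative assumptions mirroring tests/test_rotary.py, not requirements of the patch.

import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch

device = "cuda"
batch, seqlen, nheads, headdim = 2, 128, 4, 64
rotary_dim = headdim  # rotate the full head dimension; must be <= headdim

x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device=device)
# cos/sin cache of shape (seqlen_ro, rotary_dim / 2); seqlen_ro must cover seqlen + offsets
angle = torch.rand(2 * seqlen, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(torch.float16)
sin = torch.sin(angle).to(torch.float16)

# Per-sequence offsets, e.g. current KV-cache lengths during batched inference
seqlen_offsets = torch.randint(0, seqlen, (batch,), dtype=torch.int32, device=device)

# Triton kernel: rotates the first rotary_dim of each head (GPT-NeoX style here)
out = apply_rotary_emb(x, cos, sin, interleaved=False, seqlen_offsets=seqlen_offsets)

# Reference: gather the per-sequence cos/sin rows, then apply the PyTorch implementation
idx = seqlen_offsets[:, None].long() + torch.arange(seqlen, device=device)
out_ref = apply_rotary_emb_torch(x.float(), cos[idx].float(), sin[idx].float()).to(torch.float16)
print((out - out_ref).abs().max().item())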