diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index 64ae7de..5f68a84 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -402,10 +402,10 @@ class RotaryEmbedding(torch.nn.Module):
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if isinstance(seqlen_offset, int):
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-        elif max_seqlen is not None:
+        if max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
         if kv is None:
             if self.scale is None:
                 return apply_rotary_emb_qkv_(
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 3009fd1..bf34b47 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -606,6 +606,9 @@ class MHA(nn.Module):
             else {"key_padding_mask": key_padding_mask, **kwargs}
         )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
@@ -623,7 +626,9 @@ class MHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -669,7 +674,9 @@ class MHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
@@ -851,6 +858,9 @@ class ParallelMHA(nn.Module):
         if seqlen is not None:
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
@@ -859,7 +869,9 @@ class ParallelMHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -889,7 +901,9 @@ class ParallelMHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
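
For context, a minimal sketch of the behavior this change enables, assuming flash-attn is installed with its CUDA extensions and a GPU is available; the tensor shapes and `max_decode_len` value below are made up for illustration. When `max_seqlen` is passed, `RotaryEmbedding.forward` now sizes the cos/sin cache once for the whole generation budget instead of growing it from `seqlen + seqlen_offset` as decoding advances; inside `MHA`/`ParallelMHA` that value is taken from `inference_params.max_sequence_len`.

```python
import torch
from flash_attn.layers.rotary import RotaryEmbedding

device = "cuda"
batch, prompt_len, nheads, headdim = 2, 16, 8, 64
max_decode_len = 256  # assumed upper bound on prompt + generated tokens

rotary = RotaryEmbedding(headdim).to(device)

# Prefill: passing max_seqlen builds the cos/sin cache up to max_decode_len
# once (the new first branch in forward), rather than only up to
# seqlen + seqlen_offset.
qkv = torch.randn(batch, prompt_len, 3, nheads, headdim, device=device, dtype=torch.float16)
qkv = rotary(qkv, seqlen_offset=0, max_seqlen=max_decode_len)

# Decode step: one new token at offset prompt_len; the cache is already large
# enough, so no per-step cache regrowth is needed.
qkv_step = torch.randn(batch, 1, 3, nheads, headdim, device=device, dtype=torch.float16)
qkv_step = rotary(qkv_step, seqlen_offset=prompt_len, max_seqlen=max_decode_len)
```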