diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 4ec049e..215f518 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -417,7 +417,7 @@ class RotaryEmbedding(torch.nn.Module): kv: Optional[torch.Tensor] = None, seqlen_offset: Union[int, torch.Tensor] = 0, max_seqlen: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim)