From 3566596ad867ee415dd3c12616dd50c610176f6c Mon Sep 17 00:00:00 2001 From: Antony Frolov Date: Thu, 9 Nov 2023 22:43:02 +0300 Subject: [PATCH] Fix typo in RotaryEmbedding forward output type (#666) --- flash_attn/layers/rotary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 4ec049e..215f518 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -417,7 +417,7 @@ class RotaryEmbedding(torch.nn.Module): kv: Optional[torch.Tensor] = None, seqlen_offset: Union[int, torch.Tensor] = 0, max_seqlen: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim)