Fix typo in RotaryEmbedding forward output type (#666)

This commit is contained in:
Antony Frolov 2023-11-09 22:43:02 +03:00 committed by GitHub
parent 83aef842be
commit 3566596ad8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -417,7 +417,7 @@ class RotaryEmbedding(torch.nn.Module):
kv: Optional[torch.Tensor] = None,
seqlen_offset: Union[int, torch.Tensor] = 0,
max_seqlen: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
else it's just q of shape (batch, seqlen, nheads, headdim)