Fix typo in RotaryEmbedding forward output type (#666)
This commit is contained in:
parent
83aef842be
commit
3566596ad8
@ -417,7 +417,7 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
kv: Optional[torch.Tensor] = None,
|
||||
seqlen_offset: Union[int, torch.Tensor] = 0,
|
||||
max_seqlen: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
|
||||
else it's just q of shape (batch, seqlen, nheads, headdim)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user