Add window_size option to ParallelMHA
commit a190df011c
parent 2423cca3ad
@@ -747,6 +747,7 @@ class ParallelMHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         use_flash_attn=False,
         checkpointing=False,
         sequence_parallel=True,
@@ -793,6 +794,8 @@ class ParallelMHA(nn.Module):
             )
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
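A minimal sketch of what the new guard does, mirroring the two added lines above (the values are illustrative, not from this commit): requesting a non-default window without the FlashAttention path now fails at construction time instead of being silently ignored.

# Illustrative values only; this mirrors the check added above.
window_size = (128, 0)   # sliding window: up to 128 tokens to the left, 0 to the right
use_flash_attn = False
if window_size != (-1, -1):
    assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
# -> raises AssertionError, since a local window is only supported on the flash_attn path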
@@ -816,12 +819,12 @@ class ParallelMHA(nn.Module):
                 **factory_kwargs,
             )
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
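Taken together, a hedged usage sketch of the new option. It assumes ParallelMHA from flash_attn.modules.mha takes embed_dim, num_heads, and a tensor-parallel process_group as its leading arguments; the dimensions, dtype, and single-rank process-group setup below are illustrative and not part of this commit.

import torch
import torch.distributed as dist
from flash_attn.modules.mha import ParallelMHA

# Illustrative single-rank setup; in practice this would be a tensor-parallel group.
dist.init_process_group(
    backend="nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

mha = ParallelMHA(
    embed_dim=1024,
    num_heads=16,
    process_group=dist.group.WORLD,
    causal=True,
    use_flash_attn=True,      # required whenever window_size != (-1, -1)
    window_size=(256, 0),     # local attention: 256 tokens to the left, none to the right
    device="cuda",
    dtype=torch.bfloat16,
)

x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.bfloat16)
out = mha(x)                  # same shape as x; attention is now sliding-window

Leaving window_size at its default (-1, -1) keeps full (non-windowed) attention, so existing callers are unaffected.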