Add window_size option to ParallelMHA

commit a190df011c
parent 2423cca3ad
Author: Tri Dao
Date:   2024-02-04 11:42:34 -08:00


@@ -747,6 +747,7 @@ class ParallelMHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         use_flash_attn=False,
         checkpointing=False,
         sequence_parallel=True,
@@ -793,6 +794,8 @@
             )
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
@@ -816,12 +819,12 @@
             **factory_kwargs,
         )
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
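
For context, a minimal usage sketch (not part of this commit) showing how the new window_size argument might be passed through. It assumes an already-initialized torch.distributed process group for tensor parallelism; the embed_dim, num_heads, and (1024, 0) window below are illustrative assumptions, not values taken from the repository.

# Hypothetical usage sketch: constructing a ParallelMHA layer with
# sliding-window attention. Assumes torch.distributed has already been
# initialized (e.g. via torchrun) and that flash-attn is installed, since
# window_size != (-1, -1) requires the FlashAttention code path.
import torch
import torch.distributed as dist
from flash_attn.modules.mha import ParallelMHA

process_group = dist.group.WORLD  # tensor-parallel group; illustrative choice

mha = ParallelMHA(
    embed_dim=2048,                # illustrative size
    num_heads=16,                  # illustrative size
    process_group=process_group,
    causal=True,
    use_flash_attn=True,           # required whenever window_size != (-1, -1)
    window_size=(1024, 0),         # (left, right) window; (-1, -1) disables windowing
    device="cuda",
    dtype=torch.bfloat16,
)

With use_flash_attn=False, passing a non-default window_size trips the new assert above, since the non-flash SelfAttention/CrossAttention paths do not implement local attention.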