From a190df011c28c029357fc6346bb880466262672a Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 4 Feb 2024 11:42:34 -0800
Subject: [PATCH] Add window_size option to ParallelMHA

---
 flash_attn/modules/mha.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index dd1c62e..89c7680 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -747,6 +747,7 @@ class ParallelMHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         use_flash_attn=False,
         checkpointing=False,
         sequence_parallel=True,
@@ -793,6 +794,8 @@ class ParallelMHA(nn.Module):
             )
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
@@ -816,12 +819,12 @@ class ParallelMHA(nn.Module):
             **factory_kwargs,
         )
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
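
Usage note (not part of the patch): with this change, passing window_size=(left, right) to ParallelMHA is forwarded to FlashSelfAttention / FlashCrossAttention, enabling FlashAttention's local (sliding window) attention, where each query attends to at most `left` tokens before and `right` tokens after its position; the default (-1, -1) keeps full attention, and any other value requires use_flash_attn=True per the new assert. The sketch below is a hypothetical illustration of the option, not part of this commit: only the constructor arguments visible in this diff are confirmed by the patch, and the leading arguments (embed_dim, num_heads, process_group), the causal flag, and the forward call shape are assumptions based on the surrounding flash_attn module. It assumes a single-GPU process group launched via torchrun so that the rank/world-size environment variables are set.

    # Hypothetical usage sketch for the new window_size option (assumptions noted above).
    import torch
    import torch.distributed as dist
    from flash_attn.modules.mha import ParallelMHA

    dist.init_process_group(backend="nccl")  # assumes torchrun-provided env vars

    mha = ParallelMHA(
        embed_dim=4096,
        num_heads=32,
        process_group=dist.group.WORLD,
        causal=True,
        use_flash_attn=True,        # required whenever window_size != (-1, -1)
        window_size=(1024, 0),      # attend only to the previous 1024 tokens
        sequence_parallel=False,    # keep full sequences per rank for this sketch
        device="cuda",
        dtype=torch.bfloat16,
    )

    x = torch.randn(2, 2048, 4096, device="cuda", dtype=torch.bfloat16)
    out = mha(x)                    # sliding-window attention via FlashAttention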