From ef0ed106225eb00af200f524603bc09557330386 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 31 Jan 2024 02:42:23 -0800
Subject: [PATCH] Add window_size option to MHA and GPT

---
 flash_attn/models/gpt.py  |  2 ++
 flash_attn/modules/mha.py | 33 +++++++++++++++++++++++++++++----
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 6d4b6b1..71540da 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -78,6 +78,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
     rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
     use_alibi = getattr(config, "use_alibi", False)
+    window_size = getattr(config, "window_size", (-1, -1))
     use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     if not fused_bias_fc:
@@ -110,6 +111,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
         rotary_emb_scale_base=rotary_emb_scale_base,
         rotary_emb_interleaved=rotary_emb_interleaved,
         use_alibi=use_alibi,
+        window_size=window_size,
         use_flash_attn=use_flash_attn,
         **serial_kwargs,
         **parallel_kwargs,
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 144aa48..dd1c62e 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -61,7 +61,15 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +77,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -104,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -113,6 +123,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )

@@ -128,7 +139,15 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        alibi_slopes=None,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -136,6 +155,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(
@@ -184,6 +204,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -197,6 +218,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )

@@ -372,6 +394,7 @@ class MHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -401,6 +424,8 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -431,12 +456,12 @@ class MHA(nn.Module):
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
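Usage note (not part of the patch): a minimal sketch of how the new window_size option can be exercised through MHA directly after this change. The embed_dim, num_heads, window, and sequence sizes below are illustrative assumptions, and a CUDA device with fp16 tensors is assumed since the FlashAttention kernels require it. For the GPT path, setting config.window_size is picked up by create_mixer_cls via the getattr added above.

    # Illustrative sketch, not part of the patch.
    import torch
    from flash_attn.modules.mha import MHA

    mha = MHA(
        embed_dim=1024,
        num_heads=16,
        causal=True,
        use_flash_attn=True,   # required: window_size != (-1, -1) asserts use_flash_attn
        window_size=(256, 0),  # attend to at most 256 tokens to the left, 0 to the right
        device="cuda",
        dtype=torch.float16,
    )
    x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)  # (batch, seqlen, embed_dim)
    out = mha(x)  # same shape as x; window_size=(-1, -1) keeps full (global) attention

The default (-1, -1) leaves behavior unchanged, so existing configs and checkpoints are unaffected unless they opt in to a finite window.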