Add window_size option to MHA and GPT

Author: Tri Dao
Date:   2024-01-31 02:42:23 -08:00
Parent: dc72d960a7
Commit: ef0ed10622

2 changed files with 31 additions and 4 deletions

flash_attn/models/gpt.py

@@ -78,6 +78,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
     rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
     use_alibi = getattr(config, "use_alibi", False)
+    window_size = getattr(config, "window_size", (-1, -1))
     use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     if not fused_bias_fc:
@@ -110,6 +111,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
         rotary_emb_scale_base=rotary_emb_scale_base,
         rotary_emb_interleaved=rotary_emb_interleaved,
         use_alibi=use_alibi,
+        window_size=window_size,
         use_flash_attn=use_flash_attn,
         **serial_kwargs,
         **parallel_kwargs,
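
On the GPT side the new option is read straight off the config object, as the hunk above shows. A minimal usage sketch, assuming the GPT2Config-based setup that flash_attn.models.gpt works with (the sizes and window values below are illustrative, not part of this commit):

    # Hypothetical config wiring: window_size is (left, right) in tokens;
    # the default (-1, -1) keeps full attention. The sliding-window path
    # needs use_flash_attn=True (see the assert added in mha.py below).
    from transformers import GPT2Config
    from flash_attn.models.gpt import GPTLMHeadModel

    config = GPT2Config(n_embd=1024, n_head=16, n_layer=24)
    config.use_flash_attn = True
    config.window_size = (256, 0)   # attend to 256 tokens on the left, none on the right
    model = GPTLMHeadModel(config)

Running the model still requires a CUDA device and fp16/bf16 activations, as with the other FlashAttention code paths.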

flash_attn/modules/mha.py

@@ -61,7 +61,15 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +77,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -104,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -113,6 +123,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
@@ -128,7 +139,15 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        alibi_slopes=None,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -136,6 +155,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(
@@ -184,6 +204,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -197,6 +218,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
@@ -372,6 +394,7 @@ class MHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -401,6 +424,8 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -431,12 +456,12 @@ class MHA(nn.Module):
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
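
The same argument can also be passed to MHA directly. A quick sketch under the same assumptions (illustrative sizes; running it needs a CUDA device, fp16/bf16 tensors, and the flash-attn kernels, since the new assert ties window_size to use_flash_attn):

    import torch
    from flash_attn.modules.mha import MHA

    # Causal sliding-window self-attention: each query sees at most the
    # previous 128 tokens; -1 on either side means unlimited in that direction.
    mha = MHA(
        embed_dim=1024,
        num_heads=16,
        causal=True,
        use_flash_attn=True,   # required whenever window_size != (-1, -1)
        window_size=(128, 0),
        device="cuda",
        dtype=torch.float16,
    )
    x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
    out = mha(x)               # (batch, seqlen, embed_dim)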