Add window_size option to MHA and GPT
parent: dc72d960a7
commit: ef0ed10622
@@ -78,6 +78,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
     rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
     use_alibi = getattr(config, "use_alibi", False)
+    window_size = getattr(config, "window_size", (-1, -1))
     use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     if not fused_bias_fc:
@@ -110,6 +111,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
         rotary_emb_scale_base=rotary_emb_scale_base,
         rotary_emb_interleaved=rotary_emb_interleaved,
         use_alibi=use_alibi,
+        window_size=window_size,
         use_flash_attn=use_flash_attn,
         **serial_kwargs,
         **parallel_kwargs,
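
With this change the GPT mixer factory picks `window_size` up from the config, just like the other attention options it reads via `getattr` above, and forwards it to the MHA constructor. A minimal usage sketch, assuming the flash_attn GPT stack (`GPTLMHeadModel` built on a transformers `GPT2Config`) and a CUDA/fp16 setup; everything outside the diff here (the `GPT2Config` fields, `device`/`dtype` arguments) is an assumption following the existing `use_flash_attn` / `use_alibi` pattern:

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config(n_embd=1024, n_head=16, n_layer=24, vocab_size=50257)
config.use_flash_attn = True   # sliding-window attention needs the flash_attn path
config.window_size = (256, 0)  # attend to 256 tokens on the left, none on the right
model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)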
@@ -61,7 +61,15 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +77,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -104,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -113,6 +123,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
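
`FlashSelfAttention` now stores `window_size` and forwards it to both the varlen and the padded flash-attention kernels. A short sketch of calling the module directly, assuming a CUDA device, fp16 inputs, and the usual packed-QKV layout `(batch, seqlen, 3, nheads, headdim)`; `window_size` is `(left, right)` in tokens, and `(-1, -1)` keeps global attention:

import torch
from flash_attn.modules.mha import FlashSelfAttention

attn = FlashSelfAttention(causal=True, window_size=(128, 0))
qkv = torch.randn(2, 512, 3, 16, 64, dtype=torch.float16, device="cuda")
out = attn(qkv)  # (batch, seqlen, nheads, headdim)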
@@ -128,7 +139,15 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        alibi_slopes=None,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -136,6 +155,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(
@@ -184,6 +204,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -197,6 +218,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
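
The cross-attention wrapper gets the same treatment, with `window_size` passed through to the kvpacked kernels in both branches. An analogous sketch under the same assumptions (CUDA, fp16), with Q as `(batch, seqlen_q, nheads, headdim)` and KV packed as `(batch, seqlen_k, 2, nheads_k, headdim)`:

import torch
from flash_attn.modules.mha import FlashCrossAttention

cross_attn = FlashCrossAttention(window_size=(128, 0))
q = torch.randn(2, 128, 16, 64, dtype=torch.float16, device="cuda")
kv = torch.randn(2, 512, 2, 16, 64, dtype=torch.float16, device="cuda")
out = cross_attn(q, kv)  # (batch, seqlen_q, nheads, headdim)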
@@ -372,6 +394,7 @@ class MHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -401,6 +424,8 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -431,12 +456,12 @@ class MHA(nn.Module):
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
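
At the MHA level, `window_size` is only honoured on the flash_attn code path (the new assertion enforces this) and is then baked into the inner attention classes via `partial`. A construction sketch; the other keyword arguments shown (`embed_dim`, `num_heads`, `causal`, `device`, `dtype`) are not part of this diff and are assumed from the existing MHA signature:

import torch
from flash_attn.modules.mha import MHA

mha = MHA(
    embed_dim=1024,
    num_heads=16,
    causal=True,
    use_flash_attn=True,   # required whenever window_size != (-1, -1)
    window_size=(256, 0),
    device="cuda",
    dtype=torch.float16,
)
x = torch.randn(2, 512, 1024, dtype=torch.float16, device="cuda")
y = mha(x)  # (batch, seqlen, embed_dim)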