Fix: implement deterministic backward in mha (#748)
* fix deterministic
parent 1a2c3e8c25
commit 386e391117
@@ -61,7 +61,7 @@ class FlashSelfAttention(nn.Module):
             (default: 0.0)
     """
 
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +69,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic
 
     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
@@ -103,6 +104,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             return flash_attn_qkvpacked_func(
@@ -111,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
 
 
@@ -125,7 +128,7 @@ class FlashCrossAttention(nn.Module):
             (default: 0.0)
     """
 
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -133,6 +136,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic
 
     def forward(
         self,
@@ -180,6 +184,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -192,6 +197,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
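For context, a minimal usage sketch of the new flag. This assumes the two modules are importable from flash_attn.modules.mha as in upstream FlashAttention and that a CUDA device with the FlashAttention kernels is available; the shapes, dtype, and variable names below are illustrative and not taken from this commit.

    import torch
    from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention

    # Illustrative packed-QKV input: batch 2, seqlen 128, 3 (q/k/v), 8 heads, head_dim 64.
    qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

    # With this change the flag is stored on the module and forwarded to the
    # flash_attn_*packed_func calls, selecting FlashAttention's deterministic
    # (typically somewhat slower) backward pass.
    self_attn = FlashSelfAttention(causal=True, attention_dropout=0.0, deterministic=True)
    out = self_attn(qkv)          # (2, 128, 8, 64)
    out.sum().backward()          # gradients are reproducible across identical runs

    # FlashCrossAttention gains the same keyword:
    cross_attn = FlashCrossAttention(deterministic=True)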