Fix: implement deterministic backward in mha (#748)

* Add a deterministic flag (default False) to FlashSelfAttention and FlashCrossAttention and pass it through to the flash-attn kernel calls so the backward pass can be made deterministic.
jiaxingli 2024-01-03 10:13:56 +08:00 committed by GitHub
parent 1a2c3e8c25
commit 386e391117

@@ -61,7 +61,7 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """
 
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +69,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic
 
     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
@@ -103,6 +104,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             return flash_attn_qkvpacked_func(
@@ -111,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
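
For reference, a minimal usage sketch of the new flag on the self-attention side (the FlashCrossAttention hunks continue below). This is an illustration, not part of the diff: the import path (flash-attn's own flash_attn.modules.mha), tensor shapes, and sizes are assumptions, and deterministic=True simply selects flash-attn's deterministic, slightly slower backward kernels.

import torch
from flash_attn.modules.mha import FlashSelfAttention  # assumed path; substitute this repo's mha module

# Construct the module with the new flag; the diff threads it through to
# flash_attn_qkvpacked_func / flash_attn_varlen_qkvpacked_func.
attn = FlashSelfAttention(causal=True, attention_dropout=0.0, deterministic=True)

# Packed qkv: (batch, seqlen, 3, nheads, headdim), fp16/bf16 on GPU as flash-attn requires.
qkv = torch.randn(2, 1024, 3, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)

out = attn(qkv)           # (batch, seqlen, nheads, headdim)
out.sum().backward()      # backward now runs the deterministic kernel path
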
@@ -125,7 +128,7 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """
 
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -133,6 +136,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic
 
     def forward(
         self,
@@ -180,6 +184,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -192,6 +197,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
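
The cross-attention path mirrors this: the flag stored in __init__ is forwarded to flash_attn_kvpacked_func / flash_attn_varlen_kvpacked_func, whose deterministic argument (available in recent flash-attn releases, around 2.4.1 and later) trades some backward speed and memory for reproducible gradients. A hedged sketch, with the same caveats as above (import path, a forward signature taking q and kv, and shapes are assumptions):

import torch
from flash_attn.modules.mha import FlashCrossAttention  # assumed path; substitute this repo's mha module

cross_attn = FlashCrossAttention(attention_dropout=0.0, deterministic=True)

# q: (batch, seqlen_q, nheads, headdim); kv packed: (batch, seqlen_k, 2, nheads, headdim)
q = torch.randn(2, 512, 16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
kv = torch.randn(2, 2048, 2, 16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = cross_attn(q, kv)   # (batch, seqlen_q, nheads, headdim)
out.sum().backward()      # gradients should now match across identical runs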