From 386e391117b8aced37b76f4e73bf819b8c1bbe22 Mon Sep 17 00:00:00 2001
From: jiaxingli <43110891+li126com@users.noreply.github.com>
Date: Wed, 3 Jan 2024 10:13:56 +0800
Subject: [PATCH] Fix: implement deterministic backward in mha (#748)

* fix deterministic

* fix deterministic
---
 flash_attn/modules/mha.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 16c245c..144aa48 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -61,7 +61,7 @@ class FlashSelfAttention(nn.Module):
             (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +69,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
@@ -103,6 +104,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             return flash_attn_qkvpacked_func(
@@ -111,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )


@@ -125,7 +128,7 @@ class FlashCrossAttention(nn.Module):
             (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -133,6 +136,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic

     def forward(
         self,
@@ -180,6 +184,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -192,6 +197,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
            )
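
A minimal usage sketch of the new flag (illustration only, not part of the patch): it assumes the patched flash_attn/modules/mha.py, a CUDA device, and fp16 inputs; the tensor shapes are arbitrary. With deterministic=True, the value is forwarded to flash_attn_qkvpacked_func (or the varlen variant), selecting FlashAttention's deterministic backward implementation so gradients are reproducible across runs, at some cost in backward speed. FlashCrossAttention accepts the same keyword.

    import torch
    from flash_attn.modules.mha import FlashSelfAttention

    # Illustrative sizes; packed QKV layout is (batch, seqlen, 3, nheads, headdim).
    batch, seqlen, nheads, headdim = 2, 128, 8, 64

    # deterministic=True is the constructor argument added by this patch.
    attn = FlashSelfAttention(causal=True, attention_dropout=0.0, deterministic=True)

    qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                      device="cuda", dtype=torch.float16, requires_grad=True)

    out = attn(qkv)           # (batch, seqlen, nheads, headdim)
    out.sum().backward()      # backward pass now uses the deterministic kernel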