From ac3b684cdb2d3262a0dff2b7952ee1519ffcdc9c Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 17 Apr 2023 22:34:05 -0700
Subject: [PATCH] Have a separate nn.Dropout module in SelfAttention module

---
 flash_attn/modules/mha.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index bc227f0..528bb3c 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -55,7 +55,7 @@ class FlashSelfAttention(nn.Module):
         assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)
         self.triton = triton
 
     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -84,13 +84,13 @@ class FlashSelfAttention(nn.Module):
             assert max_seqlen is not None
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_qkvpacked_func(
-                qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen = qkv.shape[0], qkv.shape[1]
             # Triton version doesn't support dropout
-            if self.triton and (self.dropout_p == 0 or not self.training):
+            if self.triton and (self.drop.p == 0 or not self.training):
                 output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
             else:
                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
@@ -98,7 +98,7 @@ class FlashSelfAttention(nn.Module):
                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                           device=qkv.device)
                 output = flash_attn_unpadded_qkvpacked_func(
-                    qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                    qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
                     softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
@@ -124,7 +124,7 @@ class FlashCrossAttention(nn.Module):
         assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)
         self.triton = triton
 
     def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
@@ -156,14 +156,14 @@ class FlashCrossAttention(nn.Module):
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_kvpacked_func(
                 q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
-                self.dropout_p if self.training else 0.0,
+                self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
-            if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
+            if self.triton and (self.drop.p == 0.0 or not self.training):  # Triton version doesn't support dropout
                 output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
             else:
                 q = rearrange(q, 'b s ... -> (b s) ...')
@@ -174,7 +174,7 @@ class FlashCrossAttention(nn.Module):
                                             dtype=torch.int32, device=kv.device)
                 output = flash_attn_unpadded_kvpacked_func(
                     q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-                    self.dropout_p if self.training else 0.0,
+                    self.drop.p if self.training else 0.0,
                     softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
@@ -195,7 +195,7 @@ class SelfAttention(nn.Module):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)
 
     def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
@@ -224,7 +224,7 @@ class SelfAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+        attention_drop = self.drop(attention)
         output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         return output
 
@@ -243,7 +243,7 @@ class CrossAttention(nn.Module):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)
 
     def forward(self, q, kv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
@@ -276,7 +276,7 @@ class CrossAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+        attention_drop = self.drop(attention)
         output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         return output
 
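
Note on the pattern this patch adopts (editorial illustration, not part of the patch): each attention
module now owns an nn.Dropout submodule instead of storing a bare dropout_p float. The eager
PyTorch paths (SelfAttention, CrossAttention) apply dropout by calling the module, which PyTorch
disables automatically under model.eval(), while the fused FlashAttention paths still pass a float
probability to the kernel by reading self.drop.p. The standalone sketch below shows the same idea
on a plain softmax attention; ToySelfAttention and its shapes are illustrative assumptions, not
flash_attn code.

    # Minimal sketch: dropout as an nn.Dropout submodule (hypothetical ToySelfAttention,
    # not part of flash_attn).
    import math
    import torch
    import torch.nn as nn


    class ToySelfAttention(nn.Module):
        def __init__(self, softmax_scale=None, attention_dropout=0.0):
            super().__init__()
            self.softmax_scale = softmax_scale
            # Submodule instead of a bare float: it appears in the module tree
            # and is switched off automatically by model.eval().
            self.drop = nn.Dropout(attention_dropout)

        def forward(self, qkv):
            # qkv: (batch, seqlen, 3, nheads, headdim), packed like the flash_attn modules expect
            q, k, v = qkv.unbind(dim=2)
            softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
            scores = torch.einsum('bthd,bshd->bhts', q, k) * softmax_scale
            attention = torch.softmax(scores, dim=-1)
            # A fused kernel would instead take `self.drop.p if self.training else 0.0`;
            # the eager path simply calls the dropout module.
            attention_drop = self.drop(attention)
            return torch.einsum('bhts,bshd->bthd', attention_drop, v)


    if __name__ == '__main__':
        qkv = torch.randn(2, 16, 3, 4, 32)
        attn = ToySelfAttention(attention_dropout=0.1)
        attn.eval()                # nn.Dropout becomes a no-op outside training
        print(attn(qkv).shape)     # torch.Size([2, 16, 4, 32])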