Have a separate nn.Dropout module in SelfAttention module

Tri Dao 2023-04-17 22:34:05 -07:00
parent df1344f866
commit ac3b684cdb


@@ -55,7 +55,7 @@ class FlashSelfAttention(nn.Module):
assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
self.causal = causal
self.softmax_scale = softmax_scale
-self.dropout_p = attention_dropout
+self.drop = nn.Dropout(attention_dropout)
self.triton = triton
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
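This is the heart of the change: the dropout probability now lives on an nn.Dropout submodule and is read back as self.drop.p wherever the kernels need a float. A minimal sketch of the pattern (AttentionStub is illustrative, not the FlashAttention class):

```python
import torch
import torch.nn as nn

# Sketch only: the probability lives on an nn.Dropout submodule but
# stays readable as a float via .p, so kernel calls that take a float
# (like the flash_attn_*_func functions) keep working unchanged.
class AttentionStub(nn.Module):
    def __init__(self, attention_dropout=0.1):
        super().__init__()
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, attention):
        # As a submodule, dropout appears in the module repr and is
        # switched off automatically by model.eval().
        return self.drop(attention)

stub = AttentionStub(0.1)
print(stub.drop.p)               # 0.1 -- usable as the float a kernel expects
stub.eval()
x = torch.ones(2, 3)
assert torch.equal(stub(x), x)   # identity in eval mode
```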
@@ -84,13 +84,13 @@ class FlashSelfAttention(nn.Module):
assert max_seqlen is not None
assert isinstance(max_seqlen, int)
return flash_attn_unpadded_qkvpacked_func(
-qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
else:
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
# Triton version doesn't support dropout
-if self.triton and (self.dropout_p == 0 or not self.training):
+if self.triton and (self.drop.p == 0 or not self.training):
output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
else:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
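The gate above only routes to the Triton kernel when dropout is effectively off, since that kernel does not implement it. A standalone sketch of the same condition (use_triton is a hypothetical helper, not part of the source):

```python
# Hypothetical helper mirroring the dispatch above: the Triton kernel
# has no dropout support, so it is only taken when dropout is inactive.
def use_triton(triton_available: bool, p: float, training: bool) -> bool:
    return triton_available and (p == 0.0 or not training)

assert use_triton(True, 0.0, True)        # no dropout configured
assert use_triton(True, 0.1, False)       # eval mode disables dropout
assert not use_triton(True, 0.1, True)    # must fall back for dropout
assert not use_triton(False, 0.1, True)
```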
@@ -98,7 +98,7 @@ class FlashSelfAttention(nn.Module):
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = flash_attn_unpadded_qkvpacked_func(
-qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
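The fallback path flattens the batch and hands the unpadded kernel cumulative sequence boundaries instead of a batch dimension. A sketch of what cu_seqlens looks like for the fixed-length case handled here, plus the rearrange round trip (head count and head dim are made up):

```python
import torch
from einops import rearrange

batch_size, seqlen = 4, 128

# Cumulative boundaries for equal-length sequences: entry i is the
# start offset of sequence i in the flattened tensor, and the last
# entry is the total token count.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32)
print(cu_seqlens)  # tensor([  0, 128, 256, 384, 512], dtype=torch.int32)

# Round trip used around the unpadded kernel call:
qkv = torch.randn(batch_size, seqlen, 3, 8, 64)
flat = rearrange(qkv, 'b s ... -> (b s) ...')                 # (512, 3, 8, 64)
back = rearrange(flat, '(b s) ... -> b s ...', b=batch_size)  # original shape
assert torch.equal(qkv, back)
```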
@@ -124,7 +124,7 @@ class FlashCrossAttention(nn.Module):
assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
self.causal = causal
self.softmax_scale = softmax_scale
-self.dropout_p = attention_dropout
+self.drop = nn.Dropout(attention_dropout)
self.triton = triton
def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
@@ -156,14 +156,14 @@ class FlashCrossAttention(nn.Module):
assert isinstance(max_seqlen, int)
return flash_attn_unpadded_kvpacked_func(
q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
-self.dropout_p if self.training else 0.0,
+self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
else:
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
-if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
+if self.triton and (self.drop.p == 0.0 or not self.training):  # Triton version doesn't support dropout
output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
else:
q = rearrange(q, 'b s ... -> (b s) ...')
@@ -174,7 +174,7 @@ class FlashCrossAttention(nn.Module):
dtype=torch.int32, device=kv.device)
output = flash_attn_unpadded_kvpacked_func(
q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-self.dropout_p if self.training else 0.0,
+self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
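For cross attention, queries and keys may differ in length, so each side gets its own cumulative-length tensor for the unpadded kvpacked kernel. A small sketch under assumed lengths:

```python
import torch

batch_size, seqlen_q, seqlen_k = 2, 64, 80  # assumed lengths

# One boundary tensor per side, built the same way as above.
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                            dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                            dtype=torch.int32)
print(cu_seqlens_q)  # tensor([  0,  64, 128], dtype=torch.int32)
print(cu_seqlens_k)  # tensor([  0,  80, 160], dtype=torch.int32)
```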
@@ -195,7 +195,7 @@ class SelfAttention(nn.Module):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
-self.dropout_p = attention_dropout
+self.drop = nn.Dropout(attention_dropout)
def forward(self, qkv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
@@ -224,7 +224,7 @@ class SelfAttention(nn.Module):
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+attention_drop = self.drop(attention)
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
return output
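For reference, the non-flash SelfAttention path above is plain softmax attention: scaled scores, an additive causal mask, softmax, dropout, then a weighted sum over values. A self-contained sketch with made-up shapes (the exact construction of causal_mask is not shown in this hunk; a large negative fill is assumed):

```python
import math
import torch
import torch.nn as nn

# Minimal sketch of the non-flash path; shapes are illustrative.
# qkv: (batch, seqlen, 3, nheads, head_dim)
torch.manual_seed(0)
b, s, h, d = 2, 16, 4, 32
qkv = torch.randn(b, s, 3, h, d)
q, k, v = qkv.unbind(dim=2)
drop = nn.Dropout(0.1)

softmax_scale = 1.0 / math.sqrt(d)
scores = torch.einsum('bthd,bshd->bhts', q, k) * softmax_scale

# Additive causal mask: a large negative value above the diagonal.
# The hunk adds the mask rather than using masked_fill_, which the
# source comment says maps to a faster kernel.
causal_mask = torch.triu(torch.full((s, s), -10000.0), diagonal=1)
scores = scores + causal_mask.to(dtype=scores.dtype)

attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
attention_drop = drop(attention)   # the commit's nn.Dropout submodule
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
print(output.shape)  # torch.Size([2, 16, 4, 32])
```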
@@ -243,7 +243,7 @@ class CrossAttention(nn.Module):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
-self.dropout_p = attention_dropout
+self.drop = nn.Dropout(attention_dropout)
def forward(self, q, kv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
@@ -276,7 +276,7 @@ class CrossAttention(nn.Module):
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+attention_drop = self.drop(attention)
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
return output
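The swap from F.dropout(attention, p if self.training else 0.0) to self.drop(attention) works because nn.Dropout consults the module's own training flag, so the explicit guard becomes redundant. A quick check of that behavior:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
attention = torch.rand(2, 4, 8, 8)
drop = nn.Dropout(0.1)

# Old pattern: thread the training flag into F.dropout by hand.
old = F.dropout(attention, 0.1 if drop.training else 0.0)

# New pattern: nn.Dropout reads its own training flag, so the guard
# disappears. In eval mode it is exactly the identity.
drop.eval()
assert torch.equal(drop(attention), attention)
drop.train()
new = drop(attention)   # masks with prob 0.1, scales survivors by 1/0.9
```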