Have a separate nn.Dropout module in SelfAttention module
This commit is contained in:
parent
df1344f866
commit
ac3b684cdb
@ -55,7 +55,7 @@ class FlashSelfAttention(nn.Module):
|
||||
assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
self.triton = triton
|
||||
|
||||
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
|
||||
@ -84,13 +84,13 @@ class FlashSelfAttention(nn.Module):
|
||||
assert max_seqlen is not None
|
||||
assert isinstance(max_seqlen, int)
|
||||
return flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
|
||||
qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
else:
|
||||
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
||||
# Triton version doesn't support dropout
|
||||
if self.triton and (self.dropout_p == 0 or not self.training):
|
||||
if self.triton and (self.drop.p == 0 or not self.training):
|
||||
output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
|
||||
else:
|
||||
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
||||
@ -98,7 +98,7 @@ class FlashSelfAttention(nn.Module):
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
|
||||
qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
||||
@ -124,7 +124,7 @@ class FlashCrossAttention(nn.Module):
|
||||
assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
self.triton = triton
|
||||
|
||||
def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
|
||||
@ -156,14 +156,14 @@ class FlashCrossAttention(nn.Module):
|
||||
assert isinstance(max_seqlen, int)
|
||||
return flash_attn_unpadded_kvpacked_func(
|
||||
q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
else:
|
||||
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
||||
seqlen_k = kv.shape[1]
|
||||
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
|
||||
if self.triton and (self.dropout_p == 0.0 or not self.training): # Triton version doesn't support dropout
|
||||
if self.triton and (self.drop.p == 0.0 or not self.training): # Triton version doesn't support dropout
|
||||
output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
|
||||
else:
|
||||
q = rearrange(q, 'b s ... -> (b s) ...')
|
||||
@ -174,7 +174,7 @@ class FlashCrossAttention(nn.Module):
|
||||
dtype=torch.int32, device=kv.device)
|
||||
output = flash_attn_unpadded_kvpacked_func(
|
||||
q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
||||
@ -195,7 +195,7 @@ class SelfAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(self, qkv, causal=None, key_padding_mask=None):
|
||||
"""Implements the multihead softmax attention.
|
||||
@ -224,7 +224,7 @@ class SelfAttention(nn.Module):
|
||||
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
||||
scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
||||
attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
|
||||
attention_drop = self.drop(attention)
|
||||
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
||||
return output
|
||||
|
||||
@ -243,7 +243,7 @@ class CrossAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(self, q, kv, causal=None, key_padding_mask=None):
|
||||
"""Implements the multihead softmax attention.
|
||||
@ -276,7 +276,7 @@ class CrossAttention(nn.Module):
|
||||
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
||||
scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
||||
attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
|
||||
attention_drop = self.drop(attention)
|
||||
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
||||
return output
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user