Have a separate nn.Dropout module in SelfAttention module
parent df1344f866
commit ac3b684cdb
@@ -55,7 +55,7 @@ class FlashSelfAttention(nn.Module):
         assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)
         self.triton = triton

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -84,13 +84,13 @@ class FlashSelfAttention(nn.Module):
             assert max_seqlen is not None
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_qkvpacked_func(
-                qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen = qkv.shape[0], qkv.shape[1]
             # Triton version doesn't support dropout
-            if self.triton and (self.dropout_p == 0 or not self.training):
+            if self.triton and (self.drop.p == 0 or not self.training):
                 output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
             else:
                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
@@ -98,7 +98,7 @@ class FlashSelfAttention(nn.Module):
                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                           device=qkv.device)
                 output = flash_attn_unpadded_qkvpacked_func(
-                    qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                    qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
                     softmax_scale=self.softmax_scale, causal=causal
                 )
             output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
@@ -124,7 +124,7 @@ class FlashCrossAttention(nn.Module):
         assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)
         self.triton = triton

     def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
@@ -156,14 +156,14 @@ class FlashCrossAttention(nn.Module):
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_kvpacked_func(
                 q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
-                self.dropout_p if self.training else 0.0,
+                self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
-            if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
+            if self.triton and (self.drop.p == 0.0 or not self.training):  # Triton version doesn't support dropout
                 output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
             else:
                 q = rearrange(q, 'b s ... -> (b s) ...')
@@ -174,7 +174,7 @@ class FlashCrossAttention(nn.Module):
                                             dtype=torch.int32, device=kv.device)
                 output = flash_attn_unpadded_kvpacked_func(
                     q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-                    self.dropout_p if self.training else 0.0,
+                    self.drop.p if self.training else 0.0,
                     softmax_scale=self.softmax_scale, causal=causal
                 )
             output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
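In the Flash* classes the dropout module is never called directly: the fused kernels apply dropout internally and only accept a probability, so the commit reads it back via self.drop.p. Below is a minimal sketch of that pattern, not the classes from this diff; it uses torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0) as a stand-in for the flash-attn kernels, and the class name and layout handling are illustrative assumptions.

# Sketch: dropout lives in an nn.Dropout submodule, but the fused kernel
# (here SDPA, standing in for flash_attn_unpadded_qkvpacked_func) still
# receives dropout as a plain float.
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedSelfAttentionSketch(nn.Module):
    def __init__(self, attention_dropout=0.0, causal=False):
        super().__init__()
        self.causal = causal
        # The probability is visible as model.drop.p and in repr(model).
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv):
        # qkv: (batch, seqlen, 3, nheads, headdim), same packed layout as above.
        q, k, v = (x.transpose(1, 2) for x in qkv.unbind(dim=2))  # (b, h, s, d)
        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.drop.p if self.training else 0.0,  # same guard as in the diff
            is_causal=self.causal,
        )
        return out.transpose(1, 2)  # back to (batch, seqlen, nheads, headdim)


m = FusedSelfAttentionSketch(attention_dropout=0.1, causal=True)
qkv = torch.randn(2, 16, 3, 4, 32)
y_train = m(qkv)          # training: dropout_p=0.1 is passed to the kernel
y_eval = m.eval()(qkv)    # eval: dropout_p=0.0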
@@ -195,7 +195,7 @@ class SelfAttention(nn.Module):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)

     def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
@@ -224,7 +224,7 @@ class SelfAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+        attention_drop = self.drop(attention)
         output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         return output

@@ -243,7 +243,7 @@ class CrossAttention(nn.Module):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)

     def forward(self, q, kv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
@@ -276,7 +276,7 @@ class CrossAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+        attention_drop = self.drop(attention)
         output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         return output

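In the non-fused SelfAttention/CrossAttention classes the attention matrix is materialized, so the module is applied directly: self.drop(attention) replaces F.dropout(attention, self.dropout_p if self.training else 0.0). Since nn.Dropout is already a no-op in eval mode, the explicit training guard is no longer needed there and behaviour is unchanged. A minimal self-contained sketch of that eager path follows; SimpleSelfAttention is an illustrative stand-in, not a class from this diff.

# Sketch of the eager (non-fused) path after this change.
import math
import torch
import torch.nn as nn


class SimpleSelfAttention(nn.Module):
    def __init__(self, attention_dropout=0.0, softmax_scale=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv):
        # qkv: (batch, seqlen, 3, nheads, headdim)
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)  # replaces the old F.dropout(...) call
        return torch.einsum('bhts,bshd->bthd', attention_drop, v)


module = SimpleSelfAttention(attention_dropout=0.1)
qkv = torch.randn(1, 8, 3, 2, 16)
out_train = module(qkv)        # dropout active (module.training is True by default)
out_eval = module.eval()(qkv)  # nn.Dropout disables itself; deterministic output

One practical upside of holding dropout in a submodule is that the probability shows up in the model's repr and the dropout layer can be targeted with hooks or swapped out without touching the forward code.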