[Block] Re-enable DropPath
This commit is contained in:
parent
9ee0ff1d9b
commit
6fc1e07da2
@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
# from torchvision.ops import StochasticDepth
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp
|
||||
@ -70,12 +70,12 @@ class Block(nn.Module):
|
||||
mlp_cls = partial(Mlp, hidden_features=4 * dim)
|
||||
self.mixer = mixer_cls(dim)
|
||||
self.dropout1 = dropout_cls(resid_dropout1)
|
||||
# self.drop_path1 = StochasticDepth(drop_path1, mode='row')
|
||||
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
|
||||
self.norm1 = norm_cls(dim)
|
||||
self.mlp = mlp_cls(dim)
|
||||
if not isinstance(self.mlp, nn.Identity):
|
||||
self.dropout2 = dropout_cls(resid_dropout2)
|
||||
# self.drop_path2 = StochasticDepth(drop_path2, mode='row')
|
||||
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
|
||||
self.norm2 = norm_cls(dim)
|
||||
|
||||
if self.fused_dropout_add_ln:
|
||||
@ -129,14 +129,13 @@ class Block(nn.Module):
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
rowscale1 = None
|
||||
# if self.drop_path1.p == 0 or not self.training:
|
||||
# rowscale1 = None
|
||||
# else:
|
||||
# rowscale1 = self.drop_path1(torch.ones(
|
||||
# hidden_states.shape[:-1], device=hidden_states.device,
|
||||
# dtype=hidden_states.dtype)
|
||||
# )
|
||||
if self.drop_path1.p == 0 or not self.training:
|
||||
rowscale1 = None
|
||||
else:
|
||||
rowscale1 = self.drop_path1(torch.ones(
|
||||
hidden_states.shape[:-1], device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
)
|
||||
hidden_states, residual = fused_add_norm_fn(
|
||||
hidden_states, residual, self.norm1.weight, self.norm1.bias,
|
||||
self.dropout1.p if self.training else 0.0, self.norm1.eps,
|
||||
@ -157,14 +156,13 @@ class Block(nn.Module):
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
# if self.drop_path2.p == 0 or not self.training:
|
||||
# rowscale2 = None
|
||||
# else:
|
||||
# rowscale2 = self.drop_path2(torch.ones(
|
||||
# hidden_states.shape[:-1], device=hidden_states.device,
|
||||
# dtype=hidden_states.dtype)
|
||||
# )
|
||||
rowscale2 = None
|
||||
if self.drop_path2.p == 0 or not self.training:
|
||||
rowscale2 = None
|
||||
else:
|
||||
rowscale2 = self.drop_path2(torch.ones(
|
||||
hidden_states.shape[:-1], device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
)
|
||||
hidden_states, residual = fused_add_norm_fn(
|
||||
hidden_states, residual, self.norm2.weight, self.norm2.bias,
|
||||
self.dropout2.p if self.training else 0.0, self.norm2.eps,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user