[Block] Re-enable DropPath
parent 9ee0ff1d9b
commit 6fc1e07da2
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor

-# from torchvision.ops import StochasticDepth
+from torchvision.ops import StochasticDepth

 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp
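For reference, torchvision.ops.StochasticDepth with mode='row' zeroes whole samples along dim 0 with probability p during training, rescales survivors by 1/(1-p) to preserve the expectation, and is the identity at eval time. A quick self-contained check (the shapes here are illustrative, not taken from this commit):

import torch
from torchvision.ops import StochasticDepth

drop_path = StochasticDepth(p=0.2, mode='row')
drop_path.train()

x = torch.ones(8, 4, 16)   # (batch, seqlen, hidden)
out = drop_path(x)
# Each batch row of `out` is either all zeros (dropped)
# or all 1 / (1 - 0.2) = 1.25 (kept and rescaled).
print(out[:, 0, 0])

drop_path.eval()
assert torch.equal(drop_path(x), x)  # identity at inference time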
@@ -70,12 +70,12 @@ class Block(nn.Module):
             mlp_cls = partial(Mlp, hidden_features=4 * dim)
         self.mixer = mixer_cls(dim)
         self.dropout1 = dropout_cls(resid_dropout1)
-        # self.drop_path1 = StochasticDepth(drop_path1, mode='row')
+        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
         self.norm1 = norm_cls(dim)
         self.mlp = mlp_cls(dim)
         if not isinstance(self.mlp, nn.Identity):
             self.dropout2 = dropout_cls(resid_dropout2)
-            # self.drop_path2 = StochasticDepth(drop_path2, mode='row')
+            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
             self.norm2 = norm_cls(dim)

         if self.fused_dropout_add_ln:
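The hunks below only touch the fused path, but the re-enabled modules also act on the plain (unfused) residual path, where DropPath can simply wrap the branch output before the residual add. A minimal runnable sketch of that composition (TinyBlock, the Linear stand-in for the mixer, and the default rates are illustrative assumptions, not flash-attn's API):

import torch
import torch.nn as nn
from torchvision.ops import StochasticDepth

class TinyBlock(nn.Module):
    def __init__(self, dim, resid_dropout1=0.1, drop_path1=0.1):
        super().__init__()
        self.mixer = nn.Linear(dim, dim)              # stand-in for MHA
        self.dropout1 = nn.Dropout(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = nn.LayerNorm(dim)

    def forward(self, hidden_states, residual=None):
        hidden_states = self.mixer(hidden_states)
        # DropPath wraps dropout on the whole branch output,
        # then the residual add and LayerNorm run unfused.
        dropped = self.drop_path1(self.dropout1(hidden_states))
        residual = (dropped + residual) if residual is not None else dropped
        return self.norm1(residual), residual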
@@ -129,14 +129,13 @@ class Block(nn.Module):
             if self.residual_in_fp32:
                 residual = residual.to(torch.float32)
         else:
-            rowscale1 = None
-            # if self.drop_path1.p == 0 or not self.training:
-            #     rowscale1 = None
-            # else:
-            #     rowscale1 = self.drop_path1(torch.ones(
-            #         hidden_states.shape[:-1], device=hidden_states.device,
-            #         dtype=hidden_states.dtype)
-            #     )
+            if self.drop_path1.p == 0 or not self.training:
+                rowscale1 = None
+            else:
+                rowscale1 = self.drop_path1(torch.ones(
+                    hidden_states.shape[:-1], device=hidden_states.device,
+                    dtype=hidden_states.dtype)
+                )
             hidden_states, residual = fused_add_norm_fn(
                 hidden_states, residual, self.norm1.weight, self.norm1.bias,
                 self.dropout1.p if self.training else 0.0, self.norm1.eps,
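Because the fused kernel consumes the whole branch in one call, DropPath cannot wrap its output directly. Instead the hunk materializes the drop decision by running drop_path1 over a tensor of ones, producing a per-row scale of 0 (dropped) or 1/(1-p) (kept) that fused_add_norm_fn multiplies in. A standalone sketch of that trick, with a plain multiplication standing in for the fused kernel (the same pattern applies to drop_path2 in the next hunk):

import torch
from torchvision.ops import StochasticDepth

torch.manual_seed(0)
drop_path = StochasticDepth(p=0.5, mode='row')
drop_path.train()

hidden_states = torch.randn(8, 4, 16)

# The diff's trick: DropPath over ones of shape hidden_states.shape[:-1]
# yields a per-row scale: 0 for dropped rows, 1 / (1 - p) for kept rows.
rowscale = drop_path(torch.ones(
    hidden_states.shape[:-1], device=hidden_states.device,
    dtype=hidden_states.dtype)
)

# Multiplying the scale back in is distributionally the same operation
# the fused add+norm kernel performs with its rowscale argument.
scaled = hidden_states * rowscale.unsqueeze(-1)
print(rowscale[:, 0])  # per-batch values, all-0 or all-2.0 (1 / (1 - 0.5))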
@@ -157,14 +156,13 @@ class Block(nn.Module):
             if self.residual_in_fp32:
                 residual = residual.to(torch.float32)
         else:
-            # if self.drop_path2.p == 0 or not self.training:
-            #     rowscale2 = None
-            # else:
-            #     rowscale2 = self.drop_path2(torch.ones(
-            #         hidden_states.shape[:-1], device=hidden_states.device,
-            #         dtype=hidden_states.dtype)
-            #     )
-            rowscale2 = None
+            if self.drop_path2.p == 0 or not self.training:
+                rowscale2 = None
+            else:
+                rowscale2 = self.drop_path2(torch.ones(
+                    hidden_states.shape[:-1], device=hidden_states.device,
+                    dtype=hidden_states.dtype)
+                )
             hidden_states, residual = fused_add_norm_fn(
                 hidden_states, residual, self.norm2.weight, self.norm2.bias,
                 self.dropout2.p if self.training else 0.0, self.norm2.eps,
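Taken together, each fused call site computes roughly the following unfused reference. A hedged sketch, with argument meanings inferred from the call sites above rather than from the fused kernel's actual signature:

import torch
import torch.nn.functional as F

def add_norm_reference(hidden_states, residual, weight, bias,
                       dropout_p, eps, rowscale=None):
    """Unfused reference for the fused_add_norm_fn call sites above
    (a sketch; the real kernel fuses these steps into one pass)."""
    x = F.dropout(hidden_states, p=dropout_p)  # call sites pass 0.0 at eval
    if rowscale is not None:
        # DropPath folded in as the per-row scale built in the hunks above.
        x = x * rowscale.unsqueeze(-1)
    residual = x + residual if residual is not None else x
    hidden_states = F.layer_norm(residual, residual.shape[-1:],
                                 weight, bias, eps)
    return hidden_states, residual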