[Block] Re-enable DropPath

This commit is contained in:
Tri Dao 2023-07-21 16:34:19 -07:00
parent 9ee0ff1d9b
commit 6fc1e07da2

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# from torchvision.ops import StochasticDepth
from torchvision.ops import StochasticDepth
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp
@ -70,12 +70,12 @@ class Block(nn.Module):
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
# self.drop_path1 = StochasticDepth(drop_path1, mode='row')
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout2)
# self.drop_path2 = StochasticDepth(drop_path2, mode='row')
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
@ -129,14 +129,13 @@ class Block(nn.Module):
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
rowscale1 = None
# if self.drop_path1.p == 0 or not self.training:
# rowscale1 = None
# else:
# rowscale1 = self.drop_path1(torch.ones(
# hidden_states.shape[:-1], device=hidden_states.device,
# dtype=hidden_states.dtype)
# )
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
@ -157,14 +156,13 @@ class Block(nn.Module):
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
# if self.drop_path2.p == 0 or not self.training:
# rowscale2 = None
# else:
# rowscale2 = self.drop_path2(torch.ones(
# hidden_states.shape[:-1], device=hidden_states.device,
# dtype=hidden_states.dtype)
# )
rowscale2 = None
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,