[Block] Re-enable DropPath

Author: Tri Dao
Date:   2023-07-21 16:34:19 -07:00
Parent: 9ee0ff1d9b
Commit: 6fc1e07da2


@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 
-# from torchvision.ops import StochasticDepth
+from torchvision.ops import StochasticDepth
 
 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp
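
For context, torchvision's `StochasticDepth` is the standard DropPath implementation: with `mode='row'` it zeroes entire rows (samples along the first dimension) with probability `p` during training, rescales the survivors by `1 / (1 - p)`, and is a no-op at eval time. A minimal sketch of that behavior, independent of this repo:

```python
import torch
from torchvision.ops import StochasticDepth

drop_path = StochasticDepth(p=0.2, mode="row")
x = torch.ones(4, 3, 8)        # (batch, seqlen, dim); shape chosen only for illustration

drop_path.train()
y = drop_path(x)               # each sample is either all zeros or scaled by 1 / (1 - 0.2)

drop_path.eval()
print(torch.equal(drop_path(x), x))   # True: identity at inference time
```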
@@ -70,12 +70,12 @@ class Block(nn.Module):
             mlp_cls = partial(Mlp, hidden_features=4 * dim)
         self.mixer = mixer_cls(dim)
         self.dropout1 = dropout_cls(resid_dropout1)
-        # self.drop_path1 = StochasticDepth(drop_path1, mode='row')
+        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
         self.norm1 = norm_cls(dim)
         self.mlp = mlp_cls(dim)
         if not isinstance(self.mlp, nn.Identity):
             self.dropout2 = dropout_cls(resid_dropout2)
-            # self.drop_path2 = StochasticDepth(drop_path2, mode='row')
+            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
             self.norm2 = norm_cls(dim)
         if self.fused_dropout_add_ln:
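
Having `drop_path1` / `drop_path2` as real submodules matters mostly for deep stacks, where the usual recipe is a per-block drop probability that grows linearly with depth. A hedged sketch of that schedule; only the keyword names mirror this diff, the rest of `Block`'s constructor signature is assumed:

```python
# Stochastic-depth "linear decay" schedule: deeper blocks are skipped more often.
n_layers = 12
max_drop_path = 0.2
drop_path_rates = [max_drop_path * i / max(n_layers - 1, 1) for i in range(n_layers)]
# -> [0.0, 0.018..., ..., 0.2]: early blocks are almost never dropped

# Hypothetical construction, remaining Block arguments omitted:
# blocks = [Block(dim, ..., drop_path1=rate, drop_path2=rate) for rate in drop_path_rates]
```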
@@ -129,14 +129,13 @@ class Block(nn.Module):
                 if self.residual_in_fp32:
                     residual = residual.to(torch.float32)
             else:
-                rowscale1 = None
-                # if self.drop_path1.p == 0 or not self.training:
-                #     rowscale1 = None
-                # else:
-                #     rowscale1 = self.drop_path1(torch.ones(
-                #         hidden_states.shape[:-1], device=hidden_states.device,
-                #         dtype=hidden_states.dtype)
-                #     )
+                if self.drop_path1.p == 0 or not self.training:
+                    rowscale1 = None
+                else:
+                    rowscale1 = self.drop_path1(torch.ones(
+                        hidden_states.shape[:-1], device=hidden_states.device,
+                        dtype=hidden_states.dtype)
+                    )
                 hidden_states, residual = fused_add_norm_fn(
                     hidden_states, residual, self.norm1.weight, self.norm1.bias,
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
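
The interesting part of this hunk is the `rowscale1` trick: the fused dropout-add-LayerNorm kernel cannot call a Python `DropPath` module on `hidden_states` itself, so the block runs `drop_path1` over a tensor of ones shaped like `hidden_states` without its last dimension. The result is 0 for dropped rows and `1 / (1 - p)` for kept rows, and the kernel scales each row by it before the residual add. A rough sketch of that idea (this is not the repo's kernel; dropout is left out and the layout is only assumed):

```python
import torch
from torchvision.ops import StochasticDepth

drop_path1 = StochasticDepth(p=0.5, mode="row")
drop_path1.train()

hidden_states = torch.randn(4, 128, 64)     # (batch, seqlen, dim); assumed layout
residual = torch.randn_like(hidden_states)

# Per-row scale exactly as computed in the diff: 0 for dropped rows, 1/(1-p) otherwise.
rowscale1 = drop_path1(torch.ones(
    hidden_states.shape[:-1], device=hidden_states.device,
    dtype=hidden_states.dtype)
)

# What the fused kernel effectively does with rowscale before the LayerNorm --
# mirroring `residual + drop_path1(hidden_states)` in the unfused path, dropout aside:
new_residual = residual + hidden_states * rowscale1.unsqueeze(-1)
```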
@@ -157,14 +156,13 @@ class Block(nn.Module):
                     if self.residual_in_fp32:
                         residual = residual.to(torch.float32)
                 else:
-                    # if self.drop_path2.p == 0 or not self.training:
-                    #     rowscale2 = None
-                    # else:
-                    #     rowscale2 = self.drop_path2(torch.ones(
-                    #         hidden_states.shape[:-1], device=hidden_states.device,
-                    #         dtype=hidden_states.dtype)
-                    #     )
-                    rowscale2 = None
+                    if self.drop_path2.p == 0 or not self.training:
+                        rowscale2 = None
+                    else:
+                        rowscale2 = self.drop_path2(torch.ones(
+                            hidden_states.shape[:-1], device=hidden_states.device,
+                            dtype=hidden_states.dtype)
+                        )
                     hidden_states, residual = fused_add_norm_fn(
                         hidden_states, residual, self.norm2.weight, self.norm2.bias,
                         self.dropout2.p if self.training else 0.0, self.norm2.eps,
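
Both hunks make the fused path match the unfused prenorm path, in which each residual branch goes through dropout and DropPath before being added and normalized. A hedged sketch of that reference semantics (attribute names follow this diff; mixer/mlp internals, fp32 residual handling, and extra kwargs are omitted or assumed):

```python
def prenorm_block_reference(block, hidden_states, residual=None):
    # Attention branch: dropout + DropPath, residual add, then LayerNorm before the mixer.
    dropped = block.drop_path1(block.dropout1(hidden_states))
    residual = dropped if residual is None else dropped + residual
    hidden_states = block.mixer(block.norm1(residual))

    # MLP branch: same pattern with the second dropout / drop_path / norm.
    dropped = block.drop_path2(block.dropout2(hidden_states))
    residual = dropped + residual
    hidden_states = block.mlp(block.norm2(residual))
    return hidden_states, residual
```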