From 6fc1e07da22a344b6f0927b9e21e0eafb31fda99 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Fri, 21 Jul 2023 16:34:19 -0700
Subject: [PATCH] [Block] Re-enable DropPath

---
 flash_attn/modules/block.py | 36 +++++++++++++++++-------------------
 1 file changed, 17 insertions(+), 19 deletions(-)

diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py
index 5a25d79..a4ff5a2 100644
--- a/flash_attn/modules/block.py
+++ b/flash_attn/modules/block.py
@@ -8,7 +8,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-# from torchvision.ops import StochasticDepth
+from torchvision.ops import StochasticDepth
 
 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp
@@ -70,12 +70,12 @@ class Block(nn.Module):
             mlp_cls = partial(Mlp, hidden_features=4 * dim)
         self.mixer = mixer_cls(dim)
         self.dropout1 = dropout_cls(resid_dropout1)
-        # self.drop_path1 = StochasticDepth(drop_path1, mode='row')
+        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
         self.norm1 = norm_cls(dim)
         self.mlp = mlp_cls(dim)
         if not isinstance(self.mlp, nn.Identity):
             self.dropout2 = dropout_cls(resid_dropout2)
-            # self.drop_path2 = StochasticDepth(drop_path2, mode='row')
+            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
             self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
@@ -129,14 +129,13 @@ class Block(nn.Module):
                 if self.residual_in_fp32:
                     residual = residual.to(torch.float32)
             else:
-                rowscale1 = None
-                # if self.drop_path1.p == 0 or not self.training:
-                #     rowscale1 = None
-                # else:
-                #     rowscale1 = self.drop_path1(torch.ones(
-                #         hidden_states.shape[:-1], device=hidden_states.device,
-                #         dtype=hidden_states.dtype)
-                #     )
+                if self.drop_path1.p == 0 or not self.training:
+                    rowscale1 = None
+                else:
+                    rowscale1 = self.drop_path1(torch.ones(
+                        hidden_states.shape[:-1], device=hidden_states.device,
+                        dtype=hidden_states.dtype)
+                    )
                 hidden_states, residual = fused_add_norm_fn(
                     hidden_states, residual, self.norm1.weight, self.norm1.bias,
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
@@ -157,14 +156,13 @@ class Block(nn.Module):
                     if self.residual_in_fp32:
                         residual = residual.to(torch.float32)
                 else:
-                    # if self.drop_path2.p == 0 or not self.training:
-                    #     rowscale2 = None
-                    # else:
-                    #     rowscale2 = self.drop_path2(torch.ones(
-                    #         hidden_states.shape[:-1], device=hidden_states.device,
-                    #         dtype=hidden_states.dtype)
-                    #     )
-                    rowscale2 = None
+                    if self.drop_path2.p == 0 or not self.training:
+                        rowscale2 = None
+                    else:
+                        rowscale2 = self.drop_path2(torch.ones(
+                            hidden_states.shape[:-1], device=hidden_states.device,
+                            dtype=hidden_states.dtype)
+                        )
                     hidden_states, residual = fused_add_norm_fn(
                         hidden_states, residual, self.norm2.weight, self.norm2.bias,
                         self.dropout2.p if self.training else 0.0, self.norm2.eps,
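
A note on why the rowscale trick works (editorial, not part of the patch): torchvision's StochasticDepth(p, mode='row') zeroes each row of its input independently with probability p during training and scales surviving rows by 1/(1-p). Applying it to a tensor of ones shaped like hidden_states.shape[:-1] therefore produces exactly the per-row factors DropPath would have applied, and passing that tensor as rowscale lets the fused dropout-add-layernorm kernel fold the multiplication in. The minimal sketch below checks this equivalence in plain PyTorch; the tensor shapes and the seed-replay comparison are illustrative assumptions, not code from the repo.

import torch
from torchvision.ops import StochasticDepth

drop_path = StochasticDepth(p=0.2, mode="row")
drop_path.train()  # DropPath is only active in training mode

batch, seqlen, dim = 4, 3, 8
hidden_states = torch.randn(batch, seqlen, dim)

# Direct application, as in the non-fused branch of Block.forward.
torch.manual_seed(0)
direct = drop_path(hidden_states)

# Rowscale formulation, as in the fused branch: build per-row scales from a
# tensor of ones; the fused kernel multiplies them back in, which we emulate
# here with a broadcasted multiply.
torch.manual_seed(0)  # replay the same random mask so the two paths match
rowscale = drop_path(torch.ones(hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype))
equivalent = hidden_states * rowscale.unsqueeze(-1)

assert torch.allclose(direct, equivalent)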