[ViT] Use dropout_add_ln for the 1st layer norm
This commit is contained in:
parent
45bcf37b97
commit
1feb94265c
@ -104,14 +104,14 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
|
||||
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
|
||||
|
||||
|
||||
class GPT2Model(nn.Module):
|
||||
class GPTModel(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
super().__init__()
|
||||
self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
|
||||
if self.pad_vocab_size_multiple_8:
|
||||
if config.vocab_size % 8 != 0:
|
||||
config.vocab_size += 8 - (config.vocab_size % 8)
|
||||
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
if config.vocab_size % self.pad_vocab_size_multiple != 0:
|
||||
config.vocab_size += (self.pad_vocab_size_multiple
|
||||
- (config.vocab_size % self.pad_vocab_size_multiple))
|
||||
|
||||
self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
|
||||
config.max_position_embeddings)
|
||||
@ -153,11 +153,11 @@ class GPT2Model(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPT2LMHeadModel(nn.Module):
|
||||
class GPTLMHeadModel(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
super().__init__()
|
||||
self.transformer = GPT2Model(config)
|
||||
self.transformer = GPTModel(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
|
||||
|
||||
@ -18,6 +18,11 @@ from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
|
||||
from flash_attn.modules.block import Block
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
except ImportError:
|
||||
dropout_add_layer_norm = None
|
||||
|
||||
|
||||
def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
|
||||
cross_attn=False):
|
||||
@ -152,6 +157,10 @@ class VisionTransformer(nn.Module):
|
||||
# (in the pretrained weight) is the final layer norm.
|
||||
self.norm_0 = norm_layer(embed_dim)
|
||||
|
||||
self.fused_dropout_add_ln = fused_dropout_add_ln
|
||||
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
|
||||
raise ImportError('dropout_add_layer_norm is not installed')
|
||||
|
||||
self.blocks = nn.ModuleList([create_block(
|
||||
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
|
||||
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
|
||||
@ -193,7 +202,7 @@ class VisionTransformer(nn.Module):
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
return self.pos_drop(x)
|
||||
return x
|
||||
|
||||
def forward_features(self, x, all_tokens=True):
|
||||
"""
|
||||
@ -201,8 +210,17 @@ class VisionTransformer(nn.Module):
|
||||
cls token.
|
||||
"""
|
||||
x = self.patch_embed(x)
|
||||
x = self._pos_embed(x)
|
||||
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
|
||||
residual = self._pos_embed(x).float()
|
||||
if not self.fused_dropout_add_ln:
|
||||
residual = self.pos_drop(x).float()
|
||||
hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
|
||||
else:
|
||||
hidden_states, residual = dropout_add_layer_norm(
|
||||
x, None, self.norm_0.weight, self.norm_0.bias,
|
||||
self.pos_drop.p if self.training else 0.0, self.norm_0.eps, prenorm=True,
|
||||
residual_in_fp32=True
|
||||
)
|
||||
hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
|
||||
if self.global_pool != 'token' or all_tokens:
|
||||
for block in self.blocks:
|
||||
|
||||
@ -64,7 +64,6 @@ class FusedDenseGeluDense(nn.Module):
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.dtype in [torch.float16, torch.bfloat16]
|
||||
assert x.is_cuda
|
||||
fn = (fused_dense_gelu_dense_function_td if not self.return_residual
|
||||
else fused_dense_res_gelu_dense_function_td)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user