From 1feb94265c276dac0ba8c6cd32c436afca93d948 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 23 Nov 2022 12:48:56 -0800
Subject: [PATCH] [ViT] Use dropout_add_ln for the 1st layer norm

---
 flash_attn/models/gpt.py  | 14 +++++++-------
 flash_attn/models/vit.py  | 22 ++++++++++++++++++++--
 flash_attn/modules/mlp.py |  1 -
 3 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index c8cdafb..f8bd119 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -104,14 +104,14 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
             nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
 
 
-class GPT2Model(nn.Module):
+class GPTModel(nn.Module):
 
     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
-        if self.pad_vocab_size_multiple_8:
-            if config.vocab_size % 8 != 0:
-                config.vocab_size += 8 - (config.vocab_size % 8)
+        self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+        if config.vocab_size % self.pad_vocab_size_multiple != 0:
+            config.vocab_size += (self.pad_vocab_size_multiple
+                                  - (config.vocab_size % self.pad_vocab_size_multiple))
         self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
                                          config.max_position_embeddings)
 
@@ -153,11 +153,11 @@ class GPT2Model(nn.Module):
         return hidden_states
 
 
-class GPT2LMHeadModel(nn.Module):
+class GPTLMHeadModel(nn.Module):
 
     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.transformer = GPT2Model(config)
+        self.transformer = GPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # Initialize weights and apply final processing
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
diff --git a/flash_attn/models/vit.py b/flash_attn/models/vit.py
index da354aa..ad39c4f 100644
--- a/flash_attn/models/vit.py
+++ b/flash_attn/models/vit.py
@@ -18,6 +18,11 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
 from flash_attn.modules.block import Block
 
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+except ImportError:
+    dropout_add_layer_norm = None
+
 
 def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
                      cross_attn=False):
@@ -152,6 +157,10 @@ class VisionTransformer(nn.Module):
         # (in the pretrained weight) is the final layer norm.
         self.norm_0 = norm_layer(embed_dim)
 
+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+            raise ImportError('dropout_add_layer_norm is not installed')
+
         self.blocks = nn.ModuleList([create_block(
             embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
             drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
@@ -193,7 +202,7 @@ class VisionTransformer(nn.Module):
         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         x = x + self.pos_embed
-        return self.pos_drop(x)
+        return x
 
     def forward_features(self, x, all_tokens=True):
         """
@@ -201,8 +210,17 @@ class VisionTransformer(nn.Module):
         cls token.
""" x = self.patch_embed(x) + x = self._pos_embed(x) # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed - residual = self._pos_embed(x).float() + if not self.fused_dropout_add_ln: + residual = self.pos_drop(x).float() + hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype)) + else: + hidden_states, residual = dropout_add_layer_norm( + x, None, self.norm_0.weight, self.norm_0.bias, + self.pos_drop.p if self.training else 0.0, self.norm_0.eps, prenorm=True, + residual_in_fp32=True + ) hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype)) if self.global_pool != 'token' or all_tokens: for block in self.blocks: diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py index c7916ed..f3e749a 100644 --- a/flash_attn/modules/mlp.py +++ b/flash_attn/modules/mlp.py @@ -64,7 +64,6 @@ class FusedDenseGeluDense(nn.Module): self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs) def forward(self, x): - assert x.dtype in [torch.float16, torch.bfloat16] assert x.is_cuda fn = (fused_dense_gelu_dense_function_td if not self.return_residual else fused_dense_res_gelu_dense_function_td)