Inverse state dict for BERT (#527)

Kevin Hu 2023-09-09 01:44:21 -07:00 committed by GitHub
parent a86442f0f3
commit 4c91621a5e
3 changed files with 170 additions and 15 deletions

.gitignore

@@ -22,3 +22,6 @@ var/
# IDE-related
.idea/

# Dev
venv

flash_attn/models/bert.py

@@ -10,23 +10,19 @@ import re
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from typing import Any, Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig, PretrainedConfig
from transformers.models.bert.modeling_bert import (
    BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)

from flash_attn.bert_padding import (index_first_axis,
                                     index_first_axis_residual, pad_input,
                                     unpad_input)
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.modules.mha import MHA from flash_attn.modules.mha import MHA
@@ -511,7 +507,11 @@ class BertForPreTraining(BertPreTrainedModel):
        )


def remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
    """
    # LayerNorm
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
@@ -618,3 +618,133 @@ def remap_state_dict(state_dict, config):
    )
    return state_dict

def inv_remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
    This function is meant to be the inverse of remap_state_dict.
    """
    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        # unpad embeddings
        state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
            : config.orig_vocab_size, :
        ]
        state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
        state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]

    for d in range(config.num_hidden_layers):
        last_layer_subset = getattr(config, "last_layer_subset", False)
        if not last_layer_subset or d != (config.num_hidden_layers - 1):
            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
                : Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
                Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
                2 * Wqkv_weights.shape[0] // 3 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
                : Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
                2 * Wqkv_biases.shape[0] // 3 :
            ]
        else:
            Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
            Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
            Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
            Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
                : Wkv_weights.shape[0] // 2, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
                Wkv_weights.shape[0] // 2 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                : Wkv_biases.shape[0] // 2
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
                Wkv_biases.shape[0] // 2 :
            ]

    def inv_key_mapping_ln(key):
        key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
            r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
            r"bert.encoder.layers.\1.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"cls.predictions.transform.layer_norm.(weight|bias)",
            r"cls.predictions.transform.LayerNorm.\1",
            key,
        )
        return key

    def inv_key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
        key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
        return key

    def inv_key_mapping_layers(key):
        return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)

    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
            r"bert.encoder.layer.\1.intermediate.dense.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
            r"bert.encoder.layer.\1.output.dense.\2",
            key,
        )
        return key

    def inv_key_mapping_attn(key):
        return re.sub(
            r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
            r"bert.encoder.layer.\1.attention.output.dense.\2",
            key,
        )

    def inv_key_mapping_decoder_bias(key):
        return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)

    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
    )

    return state_dict
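A quick sanity check on the slicing above (an illustrative snippet from the editor, not part of this diff): it assumes the fused Wqkv parameters in the flash_attn layout are a plain row-wise concatenation of the query, key, and value projections, which is exactly what the thirds-slicing in inv_remap_state_dict undoes.

import torch

hidden_size = 768  # e.g. bert-base

# Separate projection weights as they appear in the HF checkpoint.
q = torch.randn(hidden_size, hidden_size)
k = torch.randn(hidden_size, hidden_size)
v = torch.randn(hidden_size, hidden_size)

# Assumed fused layout: query/key/value stacked row-wise into a single Wqkv.
Wqkv = torch.cat([q, k, v], dim=0)

# The inverse mapping recovers them by slicing the rows into thirds.
n = Wqkv.shape[0]
assert torch.equal(Wqkv[: n // 3], q)
assert torch.equal(Wqkv[n // 3 : 2 * n // 3], k)
assert torch.equal(Wqkv[2 * n // 3 :], v)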

tests/models/test_bert.py

@@ -5,12 +5,15 @@ import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import \
    BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF

from flash_attn.models.bert import (BertForPreTraining, BertModel,
                                    inv_remap_state_dict, remap_state_dict)
from flash_attn.utils.pretrained import state_dict_from_pretrained
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
@@ -43,7 +46,7 @@ def get_hf_models(model_name, config, dtype):
    return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -297,3 +300,22 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
    ).abs().max().item()
    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_inv_remap_state_dict(model_name: str):
    """
    Verify that we can convert a HF BERT model to flash_attn and back.
    """
    state_dict = state_dict_from_pretrained(model_name)
    config = BertConfig.from_pretrained(model_name)

    flash_state_dict = remap_state_dict(state_dict, config)
    recovered_state_dict = inv_remap_state_dict(flash_state_dict, config)

    assert set(state_dict.keys()) == set(recovered_state_dict.keys())
    for k in state_dict.keys():
        assert state_dict[k].shape == recovered_state_dict[k].shape
        torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)
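Beyond key, shape, and value equality, a natural follow-up is to load the recovered state dict back into a stock Hugging Face model. This is an editor's sketch rather than part of the commit; strict=False is an assumption, since the prediction-head bias can live under a tied or aliased key ("cls.predictions.bias" vs. "cls.predictions.decoder.bias") depending on the checkpoint.

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF

from flash_attn.models.bert import inv_remap_state_dict, remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "bert-base-uncased"
config = BertConfig.from_pretrained(model_name)

# HF layout -> flash_attn layout -> back to HF layout, mirroring test_inv_remap_state_dict.
hf_state_dict = state_dict_from_pretrained(model_name)
recovered = inv_remap_state_dict(remap_state_dict(hf_state_dict, config), config)

# Load into a stock HF model; strict=False tolerates the tied prediction-head bias key.
model_hf = BertForPreTrainingHF(config)
incompatible = model_hf.load_state_dict(recovered, strict=False)
print(incompatible.missing_keys, incompatible.unexpected_keys)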