Inverse state dict for BERT (#527)
This commit is contained in:
parent
a86442f0f3
commit
4c91621a5e
3
.gitignore
vendored
3
.gitignore
vendored
@ -22,3 +22,6 @@ var/
|
||||
|
||||
# IDE-related
|
||||
.idea/
|
||||
|
||||
# Dev
|
||||
venv
|
||||
@ -10,23 +10,19 @@ import re
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
from typing import Any, Mapping
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import BertConfig
|
||||
from transformers import BertConfig, PretrainedConfig
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BertForPreTrainingOutput,
|
||||
)
|
||||
BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
|
||||
|
||||
from flash_attn.bert_padding import (
|
||||
index_first_axis,
|
||||
index_first_axis_residual,
|
||||
pad_input,
|
||||
unpad_input,
|
||||
)
|
||||
from flash_attn.bert_padding import (index_first_axis,
|
||||
index_first_axis_residual, pad_input,
|
||||
unpad_input)
|
||||
from flash_attn.modules.block import Block
|
||||
from flash_attn.modules.embedding import BertEmbeddings
|
||||
from flash_attn.modules.mha import MHA
|
||||
@ -511,7 +507,11 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
def remap_state_dict(state_dict, config):
|
||||
def remap_state_dict(state_dict, config: PretrainedConfig):
|
||||
"""
|
||||
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
||||
"""
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln_gamma_beta(key):
|
||||
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
|
||||
@ -618,3 +618,133 @@ def remap_state_dict(state_dict, config):
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def inv_remap_state_dict(state_dict, config: PretrainedConfig):
|
||||
"""
|
||||
Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
|
||||
|
||||
This function is meant to be the inverse of remap_state_dict.
|
||||
"""
|
||||
# Word embedding
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
if pad_vocab_size_multiple > 1:
|
||||
word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
|
||||
decoder_weight = state_dict["cls.predictions.decoder.weight"]
|
||||
decoder_bias = state_dict["cls.predictions.decoder.bias"]
|
||||
# unpad embeddings
|
||||
state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
|
||||
: config.orig_vocab_size, :
|
||||
]
|
||||
state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
|
||||
state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
|
||||
|
||||
for d in range(config.num_hidden_layers):
|
||||
last_layer_subset = getattr(config, "last_layer_subset", False)
|
||||
if not last_layer_subset or d != (config.num_hidden_layers - 1):
|
||||
Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
|
||||
Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
|
||||
: Wqkv_weights.shape[0] // 3, :
|
||||
]
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
|
||||
Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
|
||||
]
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
|
||||
2 * Wqkv_weights.shape[0] // 3 :, :
|
||||
]
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
|
||||
: Wqkv_biases.shape[0] // 3
|
||||
]
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
|
||||
Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
|
||||
]
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
|
||||
2 * Wqkv_biases.shape[0] // 3 :
|
||||
]
|
||||
else:
|
||||
Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
|
||||
Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
|
||||
Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
|
||||
Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
|
||||
: Wkv_weights.shape[0] // 2, :
|
||||
]
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
|
||||
Wkv_weights.shape[0] // 2 :, :
|
||||
]
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
|
||||
: Wkv_biases.shape[0] // 2
|
||||
]
|
||||
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
|
||||
Wkv_biases.shape[0] // 2 :
|
||||
]
|
||||
|
||||
def inv_key_mapping_ln(key):
|
||||
key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
|
||||
key = re.sub(
|
||||
r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
|
||||
r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
|
||||
key,
|
||||
)
|
||||
key = re.sub(
|
||||
r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
|
||||
r"bert.encoder.layers.\1.output.LayerNorm.\2",
|
||||
key,
|
||||
)
|
||||
key = re.sub(
|
||||
r"cls.predictions.transform.layer_norm.(weight|bias)",
|
||||
r"cls.predictions.transform.LayerNorm.\1",
|
||||
key,
|
||||
)
|
||||
return key
|
||||
|
||||
def inv_key_mapping_ln_gamma_beta(key):
|
||||
key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
|
||||
key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
|
||||
return key
|
||||
|
||||
def inv_key_mapping_layers(key):
|
||||
return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
|
||||
|
||||
def inv_key_mapping_mlp(key):
|
||||
key = re.sub(
|
||||
r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
|
||||
r"bert.encoder.layer.\1.intermediate.dense.\2",
|
||||
key,
|
||||
)
|
||||
key = re.sub(
|
||||
r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
|
||||
r"bert.encoder.layer.\1.output.dense.\2",
|
||||
key,
|
||||
)
|
||||
return key
|
||||
|
||||
def inv_key_mapping_attn(key):
|
||||
return re.sub(
|
||||
r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
|
||||
r"bert.encoder.layer.\1.attention.output.dense.\2",
|
||||
key,
|
||||
)
|
||||
|
||||
def inv_key_mapping_decoder_bias(key):
|
||||
return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
|
||||
|
||||
state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
|
||||
state_dict = OrderedDict(
|
||||
(inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
|
||||
)
|
||||
state_dict = OrderedDict(
|
||||
(inv_key_mapping_layers(key), value) for key, value in state_dict.items()
|
||||
)
|
||||
state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
|
||||
state_dict = OrderedDict(
|
||||
(inv_key_mapping_attn(key), value) for key, value in state_dict.items()
|
||||
)
|
||||
state_dict = OrderedDict(
|
||||
(inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
@ -5,12 +5,15 @@ import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from flash_attn.models.bert import BertForPreTraining, BertModel, remap_state_dict
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from transformers import BertConfig
|
||||
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
|
||||
from transformers.models.bert.modeling_bert import \
|
||||
BertForPreTraining as BertForPreTrainingHF
|
||||
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
|
||||
|
||||
from flash_attn.models.bert import (BertForPreTraining, BertModel,
|
||||
inv_remap_state_dict, remap_state_dict)
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
|
||||
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
|
||||
@ -43,7 +46,7 @@ def get_hf_models(model_name, config, dtype):
|
||||
return model_hf
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', ["bert-base-uncased"])
|
||||
@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
|
||||
def test_bert_non_optimized(model_name):
|
||||
"""Check that our implementation of BERT (without any optimizations enabled) matches the
|
||||
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
|
||||
@ -297,3 +300,22 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
|
||||
).abs().max().item()
|
||||
# The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
|
||||
# assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
|
||||
def test_inv_remap_state_dict(model_name: str):
|
||||
"""
|
||||
Verify that we can convert a HF BERT model to flash_attn and back.
|
||||
"""
|
||||
|
||||
state_dict = state_dict_from_pretrained(model_name)
|
||||
config = BertConfig.from_pretrained(model_name)
|
||||
|
||||
flash_state_dict = remap_state_dict(state_dict, config)
|
||||
recovered_state_dict = inv_remap_state_dict(flash_state_dict, config)
|
||||
|
||||
assert set(state_dict.keys()) == set(recovered_state_dict.keys())
|
||||
|
||||
for k in state_dict.keys():
|
||||
assert state_dict[k].shape == recovered_state_dict[k].shape
|
||||
torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user