From 07005806ffba62f4e32c0bec6da0f923eef77401 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Sun, 10 Sep 2023 17:24:50 -0700 Subject: [PATCH] Add BigCode converters (#532) --- flash_attn/models/bert.py | 29 +++-- flash_attn/models/bigcode.py | 233 +++++++++++++++++++++++++++++++++++ flash_attn/models/gpt.py | 41 +++--- tests/models/test_bert.py | 15 ++- tests/models/test_bigcode.py | 206 +++++++++++++++++++++++++++++++ 5 files changed, 496 insertions(+), 28 deletions(-) create mode 100644 flash_attn/models/bigcode.py create mode 100644 tests/models/test_bigcode.py diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py index b6136dc..4aaafdf 100644 --- a/flash_attn/models/bert.py +++ b/flash_attn/models/bert.py @@ -18,11 +18,16 @@ import torch.nn.functional as F from einops import rearrange from transformers import BertConfig, PretrainedConfig from transformers.models.bert.modeling_bert import ( - BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput) + BaseModelOutputWithPoolingAndCrossAttentions, + BertForPreTrainingOutput, +) -from flash_attn.bert_padding import (index_first_axis, - index_first_axis_residual, pad_input, - unpad_input) +from flash_attn.bert_padding import ( + index_first_axis, + index_first_axis_residual, + pad_input, + unpad_input, +) from flash_attn.modules.block import Block from flash_attn.modules.embedding import BertEmbeddings from flash_attn.modules.mha import MHA @@ -75,11 +80,15 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False): inner_dim = config.intermediate_size fused_mlp = getattr(config, "fused_mlp", False) if fused_mlp: - assert config.hidden_act in ["gelu_new", "gelu_fast"], ( + assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], ( "fused_mlp only " "supports approximate gelu" ) if not fused_mlp: - approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none" + approximate = ( + "tanh" + if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] + else "none" + ) mlp_cls = partial( Mlp, hidden_features=inner_dim, @@ -232,7 +241,11 @@ class BertPredictionHeadTransform(nn.Module): raise ImportError("dropout_add_layer_norm is not installed") linear_cls = nn.Linear if not fused_bias_fc else FusedDense self.dense = linear_cls(config.hidden_size, config.hidden_size) - approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none" + approximate = ( + "tanh" + if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] + else "none" + ) self.transform_act_fn = nn.GELU(approximate=approximate) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -334,7 +347,7 @@ class BertModel(BertPreTrainedModel): self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False) if self.fused_dropout_add_ln and layer_norm is None: raise ImportError("dropout_add_layer_norm is not installed") - assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast"] + assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"] self.embeddings = BertEmbeddings( config.hidden_size, diff --git a/flash_attn/models/bigcode.py b/flash_attn/models/bigcode.py new file mode 100644 index 0000000..234944d --- /dev/null +++ b/flash_attn/models/bigcode.py @@ -0,0 +1,233 @@ +import math +import re +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig + + +def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): + """ + Map the state_dict of a Huggingface BigCode model to be flash_attn compatible. + """ + + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key) + + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.wte.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub( + r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", + r"transformer.layers.\1.norm\2.\3", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + def key_mapping_mlp(key): + key = re.sub( + r"^transformer.h.(\d+).mlp.c_fc.weight", + r"transformer.layers.\1.mlp.fc1.weight", + key, + ) + key = re.sub( + r"^transformer.h.(\d+).mlp.c_proj.weight", + r"transformer.layers.\1.mlp.fc2.weight", + key, + ) + key = re.sub( + r"^transformer.h.(\d+).mlp.c_fc.bias", + r"transformer.layers.\1.mlp.fc1.bias", + key, + ) + key = re.sub( + r"^transformer.h.(\d+).mlp.c_proj.bias", + r"transformer.layers.\1.mlp.fc2.bias", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # TODO: add support for multi-head attention + assert config.multi_query, "Only multi-query attention is supported" + + # Attention + for d in range(config.num_hidden_layers): + embed_dim = config.n_embd + head_dim = embed_dim // config.n_head + + c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight") + # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim) + # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112 + # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183 + # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim) + q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0) + # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim) + k = torch.tile(k, (config.n_head, 1)) + v = torch.tile(v, (config.n_head, 1)) + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0) + + # same deal with the bias + c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias") + # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim) + q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0) + # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim) + k = torch.tile(k, (config.n_head,)) + v = torch.tile(v, (config.n_head,)) + state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0) + + def key_mapping_attn(key): + key = re.sub( + r"^transformer.h.(\d+).attn.c_proj.weight", + r"transformer.layers.\1.mixer.out_proj.weight", + key, + ) + key = re.sub( + r"^transformer.h.(\d+).attn.c_proj.bias", + r"transformer.layers.\1.mixer.out_proj.bias", + key, + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): + """ + Map the state_dict of a flash_attn model to be Huggingface BigCode compatible. + + This function is meant to be the inverse of remap_state_dict_hf_bigcode. + """ + + # Word embedding and position embeddings + def inv_key_mapping_pos_emb(key): + return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key) + + state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") + + word_embeddings = word_embeddings[:, : config.vocab_size] + state_dict["transformer.wte.weight"] = word_embeddings + state_dict["lm_head.weight"] = word_embeddings + + # LayerNorm + def inv_key_mapping_ln(key): + key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub( + r"^transformer.layers.(\d+).norm(1|2).(weight|bias)", + r"transformer.h.\1.ln_\2.\3", + key, + ) + return key + + state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLPs + def inv_key_mapping_mlp(key): + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc1.weight", + r"transformer.h.\1.mlp.c_fc.weight", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc2.weight", + r"transformer.h.\1.mlp.c_proj.weight", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc1.bias", + r"transformer.h.\1.mlp.c_fc.bias", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.fc2.bias", + r"transformer.h.\1.mlp.c_proj.bias", + key, + ) + return key + + state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for d in range(config.num_hidden_layers): + embed_dim = config.n_embd + head_dim = embed_dim // config.n_head + + Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight") + q, k, v = torch.split( + Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0 + ) + c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0) + state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight + + # Same deal with the bias + Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias") + q, k, v = torch.split( + Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0 + ) + c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0) + state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias + + def inv_key_mapping_attn(key): + key = re.sub( + r"^transformer.layers.(\d+).mixer.out_proj.weight", + r"transformer.h.\1.attn.c_proj.weight", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mixer.out_proj.bias", + r"transformer.h.\1.attn.c_proj.bias", + key, + ) + return key + + state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config: + return GPT2Config( + activation_function=bigcode_config.activation_function, + attn_pdrop=bigcode_config.attn_pdrop, + bos_token_id=bigcode_config.bos_token_id, + embd_pdrop=bigcode_config.embd_pdrop, + eos_token_id=bigcode_config.eos_token_id, + initializer_range=bigcode_config.initializer_range, + layer_norm_epsilon=bigcode_config.layer_norm_epsilon, + max_batch_size=bigcode_config.max_batch_size, + max_sequence_length=bigcode_config.max_sequence_length, + model_type=bigcode_config.model_type, + multi_query=bigcode_config.multi_query, + n_embd=bigcode_config.n_embd, + n_head=bigcode_config.n_head, + n_inner=bigcode_config.n_inner, + n_layer=bigcode_config.n_layer, + n_positions=bigcode_config.n_positions, + resid_pdrop=bigcode_config.resid_pdrop, + scale_attn_weights=bigcode_config.scale_attn_weights, + summary_activation=bigcode_config.summary_activation, + summary_first_dropout=bigcode_config.summary_first_dropout, + summary_proj_to_labels=bigcode_config.summary_proj_to_labels, + summary_type=bigcode_config.summary_type, + summary_use_proj=bigcode_config.summary_use_proj, + use_cache=bigcode_config.use_cache, + vocab_size=bigcode_config.vocab_size, + ) diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index b8da95f..f2ae955 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -13,6 +13,7 @@ import torch.nn.functional as F from einops import rearrange from transformers import GPT2Config +from flash_attn.models.bigcode import remap_state_dict_hf_bigcode from flash_attn.models.falcon import remap_state_dict_hf_falcon from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox from flash_attn.models.gptj import remap_state_dict_hf_gptj @@ -21,12 +22,16 @@ from flash_attn.models.opt import remap_state_dict_hf_opt from flash_attn.modules.block import Block, ParallelBlock from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.modules.mha import MHA, ParallelMHA -from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP, - ParallelGatedMlp, ParallelMLP) +from flash_attn.modules.mlp import ( + FusedMLP, + GatedMlp, + Mlp, + ParallelFusedMLP, + ParallelGatedMlp, + ParallelMLP, +) from flash_attn.ops.activations import sqrelu_fwd -from flash_attn.utils.distributed import (all_gather_raw, - get_dim_for_local_rank, - sync_shared_params) +from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params from flash_attn.utils.generation import GenerationMixin from flash_attn.utils.pretrained import state_dict_from_pretrained @@ -41,8 +46,7 @@ except ImportError: dropout_add_layer_norm = None try: - from flash_attn.ops.layer_norm import \ - dropout_add_layer_norm_parallel_residual + from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual except ImportError: dropout_add_layer_norm_parallel_residual = None @@ -129,6 +133,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp "gelu_new", "gelu_fast", "gelu_approx", + "gelu_pytorch_tanh", "relu", "sqrelu", ] @@ -144,6 +149,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp "gelu_new", "gelu_fast", "gelu_approx", + "gelu_pytorch_tanh", "relu", "sqrelu", "glu", @@ -182,7 +188,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp else: approximate = ( "tanh" - if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] + if config.activation_function + in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"] else "none" ) activation = partial(F.gelu, approximate=approximate) @@ -215,7 +222,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp raise ImportError("fused_dense is not installed") activation = ( "gelu_approx" - if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] + if config.activation_function + in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"] else config.activation_function ) mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP @@ -352,6 +360,8 @@ class GPTPreTrainedModel(nn.Module): state_dict = remap_state_dict_hf_falcon(state_dict, config) elif model_name.startswith("meta-llama/Llama-"): state_dict = remap_state_dict_hf_llama(state_dict, config) + elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"): + state_dict = remap_state_dict_hf_bigcode(state_dict, config) else: raise NotImplementedError(f"Model {model_name} not supported") if world_size > 1: @@ -394,6 +404,7 @@ class GPTModel(GPTPreTrainedModel): "gelu_new", "gelu_fast", "gelu_approx", + "gelu_pytorch_tanh", "relu", "sqrelu", "glu", @@ -628,7 +639,9 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 num_last_tokens: if > 0, only return the logits for the last n tokens """ - assert input_ids.ndim == 2, f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}" + assert ( + input_ids.ndim == 2 + ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}" b, slen = input_ids.shape hidden_states = self.transformer( input_ids, position_ids=position_ids, inference_params=inference_params @@ -845,11 +858,11 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G nheadqkv=rank_n_head + 2 * rank_n_head_kv, headdim=headdim, ) - for s, rank_n_head, rank_n_head_kv in zip(state_dicts, n_head_each_rank, n_head_kv_each_rank) + for s, rank_n_head, rank_n_head_kv in zip( + state_dicts, n_head_each_rank, n_head_kv_each_rank + ) ] - wq = torch.cat( - [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0 - ) + wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0) wk = torch.cat( [ x[ diff --git a/tests/models/test_bert.py b/tests/models/test_bert.py index c3d2ddb..4c519b3 100644 --- a/tests/models/test_bert.py +++ b/tests/models/test_bert.py @@ -6,12 +6,15 @@ import torch import torch.nn.functional as F from einops import rearrange from transformers import BertConfig -from transformers.models.bert.modeling_bert import \ - BertForPreTraining as BertForPreTrainingHF +from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF from transformers.models.bert.modeling_bert import BertModel as BertModelHF -from flash_attn.models.bert import (BertForPreTraining, BertModel, - inv_remap_state_dict, remap_state_dict) +from flash_attn.models.bert import ( + BertForPreTraining, + BertModel, + inv_remap_state_dict, + remap_state_dict, +) from flash_attn.utils.pretrained import state_dict_from_pretrained @@ -102,7 +105,7 @@ def test_bert_optimized(model_name): dtype = torch.float16 config = BertConfig.from_pretrained(model_name) # Our implementation of fused_mlp assumes the activation is - # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast". + # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh". # If you just want "gelu", disable fused_mlp. config.hidden_act = "gelu_new" config.use_flash_attn = True @@ -209,7 +212,7 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs dtype = torch.float16 config = BertConfig.from_pretrained(model_name) # Our implementation of fused_mlp assumes the activation is - # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast". + # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh". # If you just want "gelu", disable fused_mlp. config.hidden_act = "gelu_new" config.use_flash_attn = True diff --git a/tests/models/test_bigcode.py b/tests/models/test_bigcode.py new file mode 100644 index 0000000..63da387 --- /dev/null +++ b/tests/models/test_bigcode.py @@ -0,0 +1,206 @@ +import time + +import pytest +import torch +from transformers import AutoTokenizer, GPTBigCodeConfig +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM + +from flash_attn.models.bigcode import bigcode_config_to_gpt2_config, inv_remap_state_dict_hf_bigcode +from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_bigcode +from flash_attn.utils.generation import update_graph_cache +from flash_attn.utils.pretrained import state_dict_from_pretrained + + +@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"]) +def test_bigcode_state_dict(model_name): + config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name)) + pretrained_state_dict = remap_state_dict_hf_bigcode( + state_dict_from_pretrained(model_name), config + ) + model = GPTLMHeadModel(config, device="meta") + state_dict = model.state_dict() + assert state_dict.keys() == pretrained_state_dict.keys() + for k in state_dict.keys(): + assert state_dict[k].shape == pretrained_state_dict[k].shape + + +@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"]) +def test_bigcode_optimized(model_name): + """Check that our implementation of BigCode (with all optimizations enabled) matches the + HF implementation: the output of our forward pass in fp16 should be around the same as the HF + forward pass in fp16, when compared to the HF forward pass in fp32. + """ + dtype = torch.float16 + device = "cuda" + config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name)) + config.use_flash_attn = True # FlashAttention-2 supports headdim 256 + config.fused_bias_fc = True + config.fused_mlp = True + config.fused_dropout_add_ln = True + config.residual_in_fp32 = True + + model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) + model.eval() + + torch.manual_seed(0) + batch_size = 2 + max_seqlen = 256 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device + ) + with torch.no_grad(): + out = model.transformer(input_ids) + logits = model(input_ids).logits + del model + + # Without device_map, the model is loaded on the CPU, which is very slow + model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device}) + model_ref.eval() + with torch.no_grad(): + out_ref = model_ref.transformer(input_ids).last_hidden_state + logits_ref = model_ref(input_ids).logits + del model_ref + + model_hf = GPTBigCodeForCausalLM.from_pretrained( + model_name, torch_dtype=dtype, device_map={"": device} + ) + model_hf.eval() + out_hf = model_hf.transformer(input_ids).last_hidden_state + logits_hf = model_hf(input_ids).logits + del model_hf + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") + assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() + + print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") + print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") + assert (logits - logits_ref).abs().max().item() < 3 * ( + logits_hf - logits_ref + ).abs().max().item() + + +@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"]) +def test_bigcode_generation(model_name): + """Check that our implementation of BigCode (with all optimizations enabled) matches the + HF implementation: the output of our forward pass in fp16 should be around the same as the HF + forward pass in fp16, when compared to the HF forward pass in fp32. + """ + dtype = torch.float16 + device = "cuda" + config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name)) + config.use_flash_attn = True # FlashAttention-2 supports headdim 256 + config.fused_bias_fc = True + config.fused_mlp = True + config.fused_dropout_add_ln = True + # Only prenorm supports residual_in_fp32 + config.residual_in_fp32 = True + + tokenizer = AutoTokenizer.from_pretrained(model_name) + eos_token_id = tokenizer.eos_token_id + + torch.manual_seed(0) + batch_size = 1 + seqlen = 100 + max_length = 150 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device + ) + + model_hf = GPTBigCodeForCausalLM.from_pretrained( + model_name, torch_dtype=dtype, device_map={"": device} + ) + model_hf.eval() + print("HF fp16") + torch.cuda.synchronize() + start = time.time() + out_hf = model_hf.generate( + input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + del model_hf + + model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device}) + model_ref.eval() + with torch.no_grad(): + logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1] + del model_ref + + model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) + model.eval() + + print("Without CUDA graph") + torch.cuda.synchronize() + start = time.time() + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + fused_ft_kernel=True, + return_dict_in_generate=True, + output_scores=True, + enable_timing=True, + teacher_outputs=out_hf.sequences, + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + + # Capture graph outside the timing loop + batch_size, seqlen_og = input_ids.shape + model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) + print("With CUDA graph") + torch.cuda.synchronize() + start = time.time() + out_cg = model.generate( + input_ids=input_ids, + max_length=max_length, + fused_ft_kernel=True, + cg=True, + return_dict_in_generate=True, + output_scores=True, + enable_timing=True, + teacher_outputs=out_hf.sequences, + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + + with torch.no_grad(): + logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1] + logits_hf = torch.stack(out_hf.scores, dim=1) + logits = torch.stack(out.scores, dim=1) + logits_cg = torch.stack(out_cg.scores, dim=1) + + del model + + hf_error = (logits_hf - logits_ref).abs().max().item() + assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error + + print(f"HF fp16 logits max diff: {hf_error}") + print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }") + assert (logits - logits_ref).abs().max().item() < 2 * hf_error + print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }") + assert (logits_cg - logits_ref).abs().max().item() < 2 * hf_error + + +@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"]) +def test_inv_remap_state_dict(model_name: str): + """ + Verify that we can convert a HF BigCode model to flash_attn and back. + """ + + state_dict = state_dict_from_pretrained(model_name) + config = GPTBigCodeConfig.from_pretrained(model_name) + + flash_state_dict = remap_state_dict_hf_bigcode(state_dict, config) + recovered_state_dict = inv_remap_state_dict_hf_bigcode(flash_state_dict, config) + + assert set(state_dict.keys()) == set(recovered_state_dict.keys()) + + for k in state_dict.keys(): + assert state_dict[k].shape == recovered_state_dict[k].shape + torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)