Add tests for Pythia, GPT-JT, and RedPajama models

This commit is contained in:
Tri Dao 2023-09-13 01:03:30 -07:00
parent bb9beb3645
commit d0032700d1
3 changed files with 26 additions and 6 deletions

View File

@@ -352,9 +352,16 @@ class GPTPreTrainedModel(nn.Module):
state_dict = remap_state_dict_hf_gpt2(state_dict, config)
elif model_name.startswith("facebook/opt"):
state_dict = remap_state_dict_hf_opt(state_dict, config)
elif model_name.startswith("EleutherAI/gpt-j-"):
elif (
model_name.startswith("EleutherAI/gpt-j-")
or model_name.startswith("togethercomputer/GPT-JT-")
):
state_dict = remap_state_dict_hf_gptj(state_dict, config)
elif model_name.startswith("EleutherAI/gpt-neox-"):
elif (
model_name.startswith("EleutherAI/gpt-neox-")
or model_name.startswith("EleutherAI/pythia-")
or model_name.startswith("togethercomputer/RedPajama-INCITE-")
):
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
elif model_name.startswith("tiiuae/falcon-"):
state_dict = remap_state_dict_hf_falcon(state_dict, config)

View File

@@ -24,7 +24,15 @@ def test_gptj_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
@pytest.mark.parametrize(
"model_name",
[
"EleutherAI/pythia-1b",
"EleutherAI/pythia-2.8b",
"EleutherAI/gpt-neox-20b",
"togethercomputer/RedPajama-INCITE-7B-Base",
],
)
def test_gpt_neox_optimized(model_name):
"""Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -35,7 +43,12 @@ def test_gpt_neox_optimized(model_name):
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
config.fused_mlp = config.activation_function in [
"gelu_fast",
"gelu_new",
"gelu_approx",
"gelu_pytorch_tanh",
]
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
@@ -54,7 +67,7 @@ def test_gpt_neox_optimized(model_name):
logits = model(input_ids).logits
del model
# Need at least 2 GPUs, otherwise we'll OOM
# Need at least 2 GPUs, otherwise we'll OOM for the 20B model
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
model_ref.eval()

View File

@@ -23,7 +23,7 @@ def test_gptj_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B", "togethercomputer/GPT-JT-6B-v1"])
def test_gptj_optimized(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF