Add tests for Pythia, GPT-JT, and RedPajama models
parent bb9beb3645
commit d0032700d1
@@ -352,9 +352,16 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_gpt2(state_dict, config)
         elif model_name.startswith("facebook/opt"):
             state_dict = remap_state_dict_hf_opt(state_dict, config)
-        elif model_name.startswith("EleutherAI/gpt-j-"):
+        elif (
+            model_name.startswith("EleutherAI/gpt-j-")
+            or model_name.startswith("togethercomputer/GPT-JT-")
+        ):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
-        elif model_name.startswith("EleutherAI/gpt-neox-"):
+        elif (
+            model_name.startswith("EleutherAI/gpt-neox-")
+            or model_name.startswith("EleutherAI/pythia-")
+            or model_name.startswith("togethercomputer/RedPajama-INCITE-")
+        ):
             state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
         elif model_name.startswith("tiiuae/falcon-"):
             state_dict = remap_state_dict_hf_falcon(state_dict, config)
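
Note: the remapping dispatch keys purely on the Hugging Face model-name prefix, so checkpoints that share an architecture reuse the same remapper (Pythia and RedPajama-INCITE are both GPT-NeoX-style; GPT-JT is a GPT-J fine-tune). A minimal sketch of the same dispatch in isolation; the import paths are assumptions based on flash-attn's models package, not confirmed by this diff:

    # Sketch only: route a checkpoint to its remapper by name prefix.
    # Import paths below are assumed; check them against your flash-attn version.
    from flash_attn.models.gptj import remap_state_dict_hf_gptj
    from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox

    NEOX_STYLE_PREFIXES = (
        "EleutherAI/gpt-neox-",
        "EleutherAI/pythia-",
        "togethercomputer/RedPajama-INCITE-",
    )

    def pick_remap_fn(model_name: str):
        # str.startswith accepts a tuple, which keeps the prefix lists flat.
        if model_name.startswith(("EleutherAI/gpt-j-", "togethercomputer/GPT-JT-")):
            return remap_state_dict_hf_gptj
        if model_name.startswith(NEOX_STYLE_PREFIXES):
            return remap_state_dict_hf_gpt_neox
        raise NotImplementedError(f"no state-dict remapping for {model_name}")
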
@@ -24,7 +24,15 @@ def test_gptj_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "EleutherAI/pythia-1b",
+        "EleutherAI/pythia-2.8b",
+        "EleutherAI/gpt-neox-20b",
+        "togethercomputer/RedPajama-INCITE-7B-Base",
+    ],
+)
 def test_gpt_neox_optimized(model_name):
     """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
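
With the expanded parametrization each checkpoint becomes its own test case, so the small Pythia models can be exercised without pulling the 20B weights. A hedged sketch of keeping the model matrix in one shared list (ALL_NEOX_MODELS is a hypothetical name, not from this diff); individual cases can then be selected with pytest's -k filter, e.g. -k pythia-1b:

    import pytest

    # Hypothetical shared list mirroring the parametrization above.
    ALL_NEOX_MODELS = [
        "EleutherAI/pythia-1b",
        "EleutherAI/pythia-2.8b",
        "EleutherAI/gpt-neox-20b",
        "togethercomputer/RedPajama-INCITE-7B-Base",
    ]

    @pytest.mark.parametrize("model_name", ALL_NEOX_MODELS)
    def test_gpt_neox_optimized(model_name):
        ...
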
@@ -35,7 +43,12 @@ def test_gpt_neox_optimized(model_name):
     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
     config.use_flash_attn = True
     config.fused_bias_fc = True
-    config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
+    config.fused_mlp = config.activation_function in [
+        "gelu_fast",
+        "gelu_new",
+        "gelu_approx",
+        "gelu_pytorch_tanh",
+    ]
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

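
This change stops hard-coding fused_mlp = True: the fused MLP path covers the tanh-approximation GELU family, and the newly added checkpoints do not all use it (GPT-NeoX-20B uses "gelu_fast" per the old comment, while the Pythia configs, for instance, use plain "gelu", which the fused kernel does not match). A hedged sketch of the resulting config setup; the import path for gpt_neox_config_to_gpt2_config is an assumption:

    from transformers import GPTNeoXConfig
    from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config  # assumed path

    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained("EleutherAI/pythia-1b"))
    # Enable the fused MLP only for tanh-approximation GELU variants; exact
    # "gelu" falls back to the unfused implementation.
    config.fused_mlp = config.activation_function in [
        "gelu_fast", "gelu_new", "gelu_approx", "gelu_pytorch_tanh",
    ]
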
@@ -54,7 +67,7 @@ def test_gpt_neox_optimized(model_name):
     logits = model(input_ids).logits
     del model

-    # Need at least 2 GPUs, otherwise we'll OOM
+    # Need at least 2 GPUs, otherwise we'll OOM for the 20B model
     # Without device_map, the model is loaded on the CPU, which is very slow
     model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
     model_ref.eval()
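
The updated comment scopes the multi-GPU requirement to the 20B model; the newly added 1B-7B checkpoints fit on a single device. device_map="auto" (via accelerate) places or shards the reference model automatically instead of loading it onto the CPU. An illustrative load in fp16 to halve the reference model's footprint; torch_dtype is a standard transformers argument, not part of this diff:

    import torch
    from transformers import GPTNeoXForCausalLM

    # device_map="auto" shards across available GPUs (requires accelerate);
    # fp16 halves the memory needed for the reference model.
    model_ref = GPTNeoXForCausalLM.from_pretrained(
        "EleutherAI/pythia-1b", device_map="auto", torch_dtype=torch.float16
    )
    model_ref.eval()
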
@@ -23,7 +23,7 @@ def test_gptj_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
+@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B", "togethercomputer/GPT-JT-6B-v1"])
 def test_gptj_optimized(model_name):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
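
GPT-JT-6B-v1 is a fine-tune of GPT-J-6B, so the existing GPT-J code path covers it and only the checkpoint name needs adding here. For context, a hedged sketch of the kind of closeness check these *_optimized tests perform, comparing fp16 outputs against an fp32 reference (function name and tolerance factor are illustrative, not from this diff):

    import torch

    def assert_logits_close(logits, logits_hf_fp16, logits_hf_fp32, factor=3.0):
        # Allow our fp16 run to deviate from the fp32 reference by at most a
        # small multiple of HF's own fp16-vs-fp32 error.
        err = (logits - logits_hf_fp32).abs().max().item()
        err_hf = (logits_hf_fp16 - logits_hf_fp32).abs().max().item()
        assert err < factor * err_hf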