diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index f2ae955..e822028 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -352,9 +352,16 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_gpt2(state_dict, config)
         elif model_name.startswith("facebook/opt"):
             state_dict = remap_state_dict_hf_opt(state_dict, config)
-        elif model_name.startswith("EleutherAI/gpt-j-"):
+        elif (
+            model_name.startswith("EleutherAI/gpt-j-")
+            or model_name.startswith("togethercomputer/GPT-JT-")
+        ):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
-        elif model_name.startswith("EleutherAI/gpt-neox-"):
+        elif (
+            model_name.startswith("EleutherAI/gpt-neox-")
+            or model_name.startswith("EleutherAI/pythia-")
+            or model_name.startswith("togethercomputer/RedPajama-INCITE-")
+        ):
             state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
         elif model_name.startswith("tiiuae/falcon-"):
             state_dict = remap_state_dict_hf_falcon(state_dict, config)
diff --git a/tests/models/test_gpt_neox.py b/tests/models/test_gpt_neox.py
index f4e27da..9ae8aa9 100644
--- a/tests/models/test_gpt_neox.py
+++ b/tests/models/test_gpt_neox.py
@@ -24,7 +24,15 @@ def test_gptj_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "EleutherAI/pythia-1b",
+        "EleutherAI/pythia-2.8b",
+        "EleutherAI/gpt-neox-20b",
+        "togethercomputer/RedPajama-INCITE-7B-Base",
+    ],
+)
 def test_gpt_neox_optimized(model_name):
     """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the HF
     implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -35,7 +43,12 @@ def test_gpt_neox_optimized(model_name):
     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
     config.use_flash_attn = True
     config.fused_bias_fc = True
-    config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
+    config.fused_mlp = config.activation_function in [
+        "gelu_fast",
+        "gelu_new",
+        "gelu_approx",
+        "gelu_pytorch_tanh",
+    ]
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
@@ -54,7 +67,7 @@ def test_gpt_neox_optimized(model_name):
     logits = model(input_ids).logits
     del model
 
-    # Need at least 2 GPUs, otherwise we'll OOM
+    # Need at least 2 GPUs, otherwise we'll OOM for the 20B model
     # Without device_map, the model is loaded on the CPU, which is very slow
     model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
     model_ref.eval()
diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py
index d31aea4..7b23d61 100644
--- a/tests/models/test_gptj.py
+++ b/tests/models/test_gptj.py
@@ -23,7 +23,7 @@ def test_gptj_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
+@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B", "togethercomputer/GPT-JT-6B-v1"])
 def test_gptj_optimized(model_name):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the HF
     implementation: the output of our forward pass in fp16 should be around the same as the HF
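
For reference, here is a minimal usage sketch (not part of the patch) showing how one of the newly supported NeoX-style checkpoints flows through the remap path added above, following the same pattern as `test_gpt_neox_optimized`. The model name, device, and sequence length are illustrative choices; `GPTLMHeadModel` is the entry point used by the existing tests.

```python
import torch
from transformers import GPTNeoXConfig

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config

# Any of the newly supported NeoX-style checkpoints (illustrative choice).
model_name = "EleutherAI/pythia-1b"

# Translate the HF GPT-NeoX config into the GPT-2-style config flash_attn uses.
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
# The fused MLP only covers the tanh-approximated GELU variants, so gate it on
# the activation function, mirroring the updated test. Pythia uses plain
# "gelu", so fused_mlp stays False for it.
config.fused_mlp = config.activation_function in [
    "gelu_fast", "gelu_new", "gelu_approx", "gelu_pytorch_tanh"
]
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True

# from_pretrained dispatches on the model name prefix, so this routes through
# the remap_state_dict_hf_gpt_neox branch added in the gpt.py hunk.
model = GPTLMHeadModel.from_pretrained(model_name, config, device="cuda", dtype=torch.float16)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 128), dtype=torch.long, device="cuda")
with torch.no_grad():
    logits = model(input_ids).logits
```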