From 56ccaff12678868c773cb9d4af7b309763173a7b Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 26 Jul 2023 07:22:22 -1000
Subject: [PATCH] [GPT] Add LLaMa-13B to test

---
 tests/models/test_llama.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py
index a6b9c6c..36807a8 100644
--- a/tests/models/test_llama.py
+++ b/tests/models/test_llama.py
@@ -183,7 +183,7 @@ def test_llama_parallel(model_name, world_size):
     assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
 
 
-@pytest.mark.parametrize('model_name', ["7B"])
+@pytest.mark.parametrize('model_name', ["7B", "13B"])
 def test_llama_generation(model_name):
     checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
@@ -219,11 +219,12 @@ def test_llama_generation(model_name):
     print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     del model_hf
 
+    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
-                                                 device_map={"": device})
+                                                 device_map='auto')
     model_ref.eval()
     with torch.no_grad():
-        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
     del model_ref
 
     ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
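
Note (not part of the patch): a minimal standalone sketch of what the device_map='auto' change does. With the Hugging Face Accelerate integration, from_pretrained(..., device_map='auto') spreads the fp32 weights across whatever GPU and CPU memory is available, so the 13B reference model can still be loaded on a machine whose single GPU cannot hold it; the outputs then have to be moved back onto the test's device before comparison, which is what the added .to(device=device) does. The checkpoint path and dummy input ids below are placeholders, not values from the test.

    # Minimal sketch, assuming transformers + accelerate are installed and at
    # least one CUDA device is present. Paths and inputs are placeholders.
    import torch
    from transformers import LlamaForCausalLM

    model_ref = LlamaForCausalLM.from_pretrained(
        "checkpoints/llama/13B-hf",  # placeholder checkpoint path
        device_map="auto",           # let accelerate shard layers across GPUs/CPU
    )
    model_ref.eval()

    input_ids = torch.randint(0, 32000, (1, 16))  # dummy token ids
    with torch.no_grad():
        # Logits come back on whichever device holds the final layers, so move
        # them explicitly before comparing against tensors on a fixed device.
        logits_ref = model_ref(input_ids).logits.to(device="cuda:0")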