[GPT] Add LLaMa-13B to test

This commit is contained in:
Tri Dao 2023-07-26 07:22:22 -10:00
parent 8e9820a55b
commit 56ccaff126

View File

@ -183,7 +183,7 @@ def test_llama_parallel(model_name, world_size):
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["7B"])
@pytest.mark.parametrize('model_name', ["7B", "13B"])
def test_llama_generation(model_name):
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
current_dir.parent.parent / 'checkpoints')) / 'llama'
@ -219,11 +219,12 @@ def test_llama_generation(model_name):
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
del model_hf
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
device_map={"": device})
device_map='auto')
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
del model_ref
ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)