[GPT] Add LLaMa-13B to test
This commit is contained in:
parent
8e9820a55b
commit
56ccaff126
@ -183,7 +183,7 @@ def test_llama_parallel(model_name, world_size):
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', ["7B"])
|
||||
@pytest.mark.parametrize('model_name', ["7B", "13B"])
|
||||
def test_llama_generation(model_name):
|
||||
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
|
||||
current_dir.parent.parent / 'checkpoints')) / 'llama'
|
||||
@ -219,11 +219,12 @@ def test_llama_generation(model_name):
|
||||
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
|
||||
del model_hf
|
||||
|
||||
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
|
||||
model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
|
||||
device_map={"": device})
|
||||
device_map='auto')
|
||||
model_ref.eval()
|
||||
with torch.no_grad():
|
||||
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
|
||||
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
|
||||
del model_ref
|
||||
|
||||
ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user