[GPT] Add LLaMa-13B to test
parent 8e9820a55b
commit 56ccaff126
@@ -183,7 +183,7 @@ def test_llama_parallel(model_name, world_size):
     assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
 
 
-@pytest.mark.parametrize('model_name', ["7B"])
+@pytest.mark.parametrize('model_name', ["7B", "13B"])
 def test_llama_generation(model_name):
     checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
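With the extra parametrize value, pytest now collects test_llama_generation[7B] and test_llama_generation[13B], and each case looks for its checkpoints under $CHECKPOINT_DIR/llama. Below is a minimal sketch of the expected layout; the directory name for the raw checkpoint shards is an assumption inferred from the test, while the '<size>-hf' converted directory comes from the test code above.

# Sketch only: check that both model sizes have the checkpoints the test expects.
# The bare '<size>' directory for the raw weights is an assumption; the
# '<size>-hf' converted Hugging Face directory is taken from the test above.
import os
from pathlib import Path

checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', 'checkpoints')) / 'llama'
for model_name in ["7B", "13B"]:
    raw_dir = checkpoint_path / model_name         # raw checkpoint shards (assumed layout)
    hf_dir = checkpoint_path / f'{model_name}-hf'  # converted HF checkpoint (from the test)
    print(f'{model_name}: raw={raw_dir.exists()} hf={hf_dir.exists()}')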
@@ -219,11 +219,12 @@ def test_llama_generation(model_name):
     print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     del model_hf
 
+    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
-                                                 device_map={"": device})
+                                                 device_map='auto')
     model_ref.eval()
     with torch.no_grad():
-        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
     del model_ref
 
     ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
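This second hunk is what lets the 13B case run at all on a single A100 40GB: loading the fp32 reference model with device_map='auto' lets Accelerate spread it across GPU and CPU, which in turn means its output logits may not land on `device`, hence the added .to(device=device). Below is a minimal sketch of that pattern, not the test itself; the checkpoint path, device string, and prompt length are placeholders.

# Minimal sketch, not the test itself: load a reference LLaMa with automatic
# device placement (requires the `accelerate` package) and move the logits back
# to a fixed device before comparing. Path and sequence length are placeholders.
import torch
from transformers import LlamaForCausalLM

device = 'cuda'
model_ref = LlamaForCausalLM.from_pretrained(
    'checkpoints/llama/13B-hf',  # placeholder path
    device_map='auto',           # shard / offload across GPU and CPU as needed
)
model_ref.eval()

input_ids = torch.randint(0, model_ref.config.vocab_size, (1, 16))
with torch.no_grad():
    # With device_map='auto', outputs can live on CPU or another GPU, so move
    # them explicitly before comparing against tensors that live on `device`.
    logits_ref = model_ref(input_ids).logits.to(device=device)
print(logits_ref.shape, logits_ref.device)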