[Gen] Measure prompt processing + decoding time, not just decoding

This commit is contained in:
Tri Dao 2023-04-13 15:39:56 -07:00
parent 6f6e9a9aaf
commit 1c9ef9b399

View File

@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
fused_ft_kernel=fused_ft_kernel)
scores = []
with torch.inference_mode():
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
start = time.time()
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits if not cg else logits.clone())
@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if inference_params.sequence_len_offset >= max_length - 1:
break
if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize()
print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),