[Gen] Measure prompt processing + decoding time, not just decoding
This commit is contained in:
parent
6f6e9a9aaf
commit
1c9ef9b399
@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
fused_ft_kernel=fused_ft_kernel)
|
||||
scores = []
|
||||
with torch.inference_mode():
|
||||
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
|
||||
if timing:
|
||||
if tensor_parallel > 1:
|
||||
torch.distributed.barrier()
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
|
||||
if vocab_size is not None:
|
||||
logits = logits[..., :vocab_size]
|
||||
scores.append(logits if not cg else logits.clone())
|
||||
@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
if inference_params.sequence_len_offset >= max_length - 1:
|
||||
break
|
||||
if timing:
|
||||
if tensor_parallel > 1:
|
||||
torch.distributed.barrier()
|
||||
torch.cuda.synchronize()
|
||||
print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
|
||||
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
|
||||
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
||||
return output_cls(
|
||||
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user