diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index 8979644..b180e8e 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                        fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if timing:
+            if tensor_parallel > 1:
+                torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
+        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())
@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
     if timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
         torch.cuda.synchronize()
-        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
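
The change starts the timer before the initial prompt forward pass (so the measurement covers prompt processing as well as decoding, matching the updated print label) and, when running with tensor parallelism, inserts torch.distributed.barrier() before torch.cuda.synchronize() so no rank starts or stops the clock while other ranks are still working. The following is a minimal standalone sketch of that timing pattern, not code from the patch; the timed_region helper name is hypothetical, and it assumes a torch.distributed process group and CUDA device are already set up on each rank.

import time
import torch
import torch.distributed as dist

def timed_region(fn, tensor_parallel=1):
    # Sketch of the barrier-then-synchronize timing pattern from the diff.
    if tensor_parallel > 1:
        dist.barrier()        # align all tensor-parallel ranks before timing
    torch.cuda.synchronize()  # drain this rank's queued CUDA work
    start = time.time()
    out = fn()                # e.g. prompt processing + decoding loop
    if tensor_parallel > 1:
        dist.barrier()        # wait for the slowest rank before reading the clock
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    return out

Without the barrier, a fast rank could record a misleadingly short time that reflects only its own work rather than the end-to-end latency of the tensor-parallel group.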