diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index fa40a64..ea01ea1 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -97,7 +97,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
     if cg:
         assert fused_ft_kernel
         run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
-    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
     if timing:
         start = time.time()
     while True:
@@ -117,8 +116,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             break
     if timing:
         print(f'Decoding time: {time.time() - start}')
-    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))
-    # prof.export_chrome_trace("gpt2s_generation.json")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),