diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index 3abc94d..a89c414 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -84,6 +84,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
         )
 
 
+@torch.inference_mode()
 def decode(
     input_ids,
     model,
@@ -97,7 +98,7 @@ def decode(
     tensor_parallel=1,
     fused_ft_kernel=False,
     cg=False,
-    timing=False,
+    enable_timing=False,
 ):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
@@ -137,73 +138,67 @@ def decode(
             max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
         )
 
-    def logits_forward_fn(input_ids, position_ids, inference_params):
-        if not cg:
-            return model(
-                input_ids,
-                position_ids=position_ids,
-                inference_params=inference_params,
-                num_last_tokens=1,
-            ).logits.squeeze(dim=1)
-        else:
-            return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
-            ).clone()
-
-    logits_postprocess_fn = (
-        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
-    )
-
-    scores = []
-    with torch.inference_mode():
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            start = time.time()
-        logits = model(
-            input_ids, inference_params=inference_params, num_last_tokens=1
-        ).logits.squeeze(dim=1)
-        logits = logits_postprocess_fn(logits)
-        scores.append(logits if not cg else logits.clone())
-        if teacher_outputs is None or teacher_output_len <= seqlen_og:
-            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-        else:
-            next_token = teacher_outputs[:, seqlen_og]
-        sequences = [next_token]
-        inference_params.sequence_len_offset = seqlen_og
-        while True:
+    def get_logits(input_ids, inference_params):
+        decoding = inference_params.sequence_len_offset > 0
+        if decoding:
             position_ids = torch.full(
                 (batch_size, 1),
                 inference_params.sequence_len_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )
-            logits = logits_postprocess_fn(
-                logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
-            )
-            scores.append(logits)
-            if (
-                teacher_outputs is None
-                or teacher_output_len <= inference_params.sequence_len_offset + 1
-            ):
-                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-            else:
-                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
-            sequences.append(next_token)
-            inference_params.sequence_len_offset += 1
-            if eos_token_id is not None and (next_token == eos_token_id).all():
-                break
-            if inference_params.sequence_len_offset >= max_length - 1:
-                break
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
+        else:
+            position_ids = None
+        if not cg or not decoding:
+            logits = model(
+                input_ids,
+                position_ids=position_ids,
+                inference_params=inference_params,
+                num_last_tokens=1,
+            ).logits.squeeze(dim=1)
+        else:
+            logits = model._decoding_cache.run(
+                input_ids, position_ids, inference_params.sequence_len_offset
+            ).clone()
+        return logits[..., :vocab_size] if vocab_size is not None else logits
+
+    def sample_tokens(logits, inference_params):
+        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+        else:
+            token = teacher_outputs[:, inference_params.sequence_len_offset]
+        return rearrange(token, "b -> b 1")
+
+    def should_stop(current_token, inference_params):
+        if inference_params.sequence_len_offset == 0:
+            return False
+        if eos_token_id is not None and (current_token == eos_token_id).all():
+            return True
+        if inference_params.sequence_len_offset >= max_length - 1:
+            return True
+        return False
+
+    start = torch.cuda.Event(enable_timing=enable_timing)
+    end = torch.cuda.Event(enable_timing=enable_timing)
+
+    if enable_timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        start.record()
+    scores, sequences = [], [input_ids]
+    while not should_stop(sequences[-1], inference_params):
+        scores.append(get_logits(sequences[-1], inference_params))
+        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        sequences.append(sample_tokens(scores[-1], inference_params))
+    if enable_timing:
+        end.record()
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        torch.cuda.synchronize()
+        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
 
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
-        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
+        sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
     )
 
 
@@ -280,7 +275,7 @@ def decode_speculative(
     tensor_parallel=1,
     fused_ft_kernel=False,
     cg=False,
-    timing=False,
+    enable_timing=False,
     debug=False,
 ):
     """
@@ -446,7 +441,7 @@ def decode_speculative(
     sequences = [input_ids]
     scores = []
     with torch.inference_mode():
-        if timing:
+        if enable_timing:
             if tensor_parallel > 1:
                 torch.distributed.barrier()
             torch.cuda.synchronize()
@@ -566,7 +561,7 @@ def decode_speculative(
             ).logits
             print((scores[-1] - scores_ref[:, :-1]).abs().max())
 
-    if timing:
+    if enable_timing:
         if tensor_parallel > 1:
             torch.distributed.barrier()
         torch.cuda.synchronize()
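Side note on the hunk above: `decode` now times prompt processing plus decoding with CUDA events rather than host-side `time.time()`, so the measurement brackets the GPU work enqueued between the two `record()` calls. A minimal standalone sketch of that pattern, with a dummy matmul standing in for the decoding loop (the workload and sizes here are illustrative only, not part of this diff):

```python
import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

a = torch.randn(1024, 1024, device="cuda")  # dummy workload, stands in for decoding
start.record()                # mark the start on the current CUDA stream
b = a @ a                     # GPU work to be timed
end.record()                  # mark the end after the kernels are enqueued
torch.cuda.synchronize()      # wait for the GPU so both events have completed
print(f"elapsed: {start.elapsed_time(end):.0f}ms")  # elapsed_time() reports milliseconds
```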
diff --git a/tests/models/test_baichuan.py b/tests/models/test_baichuan.py
index baf00c9..1ff6ea7 100644
--- a/tests/models/test_baichuan.py
+++ b/tests/models/test_baichuan.py
@@ -289,7 +289,7 @@ def test_baichuan_generation(model_name):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -310,7 +310,7 @@ def test_baichuan_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -400,7 +400,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
 
     # Capture graph outside the timing loop
@@ -419,7 +419,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
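The remaining hunks are mechanical renames at the call sites: every `timing=True` passed to `generate` becomes `enable_timing=True`. A hedged sketch of an updated call, assuming `model`, `input_ids`, `max_length`, and `out_hf` are prepared as in the surrounding tests:

```python
# Call-site sketch; `model`, `input_ids`, `max_length`, and `out_hf`
# are assumed to be set up as in the tests in this diff.
out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    fused_ft_kernel=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,  # was `timing=True` before this change
    teacher_outputs=out_hf.sequences,
)
```

Callers still passing the old `timing=` keyword would presumably fail with an unexpected-keyword-argument error, which is why every test below is updated in lockstep.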
diff --git a/tests/models/test_falcon.py b/tests/models/test_falcon.py
index 9255d7b..ecb95fd 100644
--- a/tests/models/test_falcon.py
+++ b/tests/models/test_falcon.py
@@ -245,7 +245,7 @@ def test_falcon_generation(model_name):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -264,7 +264,7 @@ def test_falcon_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -351,7 +351,7 @@ def test_falcon_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
 
     # Capture graph outside the timing loop
@@ -368,7 +368,7 @@ def test_falcon_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index 93811df..8f74e93 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -200,7 +200,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
@@ -212,7 +212,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
-            timing=True,
+            enable_timing=True,
         )
         print(out_cg.sequences)
 
@@ -267,7 +267,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
         teacher_outputs=teacher_outputs,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         **kwargs,
     )
     return torch.stack(out.scores, dim=1)
@@ -431,7 +431,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         fused_ft_kernel=fused_ft_kernel,
         cg=cg,
         speculative_lookahead=4,
-        timing=True,
+        enable_timing=True,
     )
     print(tokenizer.batch_decode(out.sequences))
     out_og = model.generate(
@@ -440,7 +440,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         top_k=5,
         fused_ft_kernel=fused_ft_kernel,
         cg=False,
-        timing=True,
+        enable_timing=True,
         return_dict_in_generate=True,
     )
     print(tokenizer.batch_decode(out_og.sequences))
diff --git a/tests/models/test_gpt_generation_parallel.py b/tests/models/test_gpt_generation_parallel.py
index 3fbc5da..ec25868 100644
--- a/tests/models/test_gpt_generation_parallel.py
+++ b/tests/models/test_gpt_generation_parallel.py
@@ -114,7 +114,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out.sequences)
     if fused_ft_kernel:
@@ -127,7 +127,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
-            timing=True,
+            enable_timing=True,
         )
         print(out_cg.sequences)
 
diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py
index 76fc7c3..8abb3b9 100644
--- a/tests/models/test_gptj.py
+++ b/tests/models/test_gptj.py
@@ -144,7 +144,7 @@ def test_gptj_generation(model_name):
         # eos_token_id=eos_token_id,
         fused_ft_kernel=False,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -163,7 +163,7 @@ def test_gptj_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py
index 48882b4..3b162ba 100644
--- a/tests/models/test_llama.py
+++ b/tests/models/test_llama.py
@@ -295,7 +295,7 @@ def test_llama_generation(model_name, checkpoint_format):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -314,7 +314,7 @@ def test_llama_generation(model_name, checkpoint_format):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -403,7 +403,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
 
     # Capture graph outside the timing loop
@@ -420,7 +420,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py
index bd3e9a9..535b76c 100644
--- a/tests/models/test_opt.py
+++ b/tests/models/test_opt.py
@@ -158,7 +158,7 @@ def test_opt_generation(model_name):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     torch.cuda.synchronize()
     print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
@@ -179,7 +179,7 @@ def test_opt_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     torch.cuda.synchronize()
     print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
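For reference, a shape-only sketch of the new bookkeeping in `decode`: `sequences` now starts with the prompt and each loop iteration appends a `(batch, 1)` chunk from `sample_tokens`, so a single `torch.cat(sequences, dim=1)` replaces the old `torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)`. Dummy tensors only, no model involved:

```python
import torch

batch_size, prompt_len, new_tokens = 2, 5, 3
input_ids = torch.randint(0, 100, (batch_size, prompt_len))

sequences = [input_ids]  # the prompt is the first entry of the output list
for _ in range(new_tokens):
    # stand-in for sample_tokens(...), which returns a (batch, 1) tensor
    sequences.append(torch.randint(0, 100, (batch_size, 1)))

out = torch.cat(sequences, dim=1)
assert out.shape == (batch_size, prompt_len + new_tokens)
```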