From 847abe653c885fe96820d6963768ddc1ab6a519b Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 26 Aug 2023 14:36:44 -0700
Subject: [PATCH] [Gen] Refactor decode function a bit

---
 flash_attn/utils/generation.py | 41 +++++++++++++++++++---------------
 1 file changed, 23 insertions(+), 18 deletions(-)

diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index ab05376..ed774d8 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -124,6 +124,24 @@ def decode(
     inference_params = InferenceParams(
         max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
     )
+
+    def logits_forward_fn(input_ids, position_ids, inference_params):
+        if not cg:
+            return model(
+                input_ids,
+                position_ids=position_ids,
+                inference_params=inference_params,
+                last_token_only=True,
+            ).logits
+        else:
+            return model._decoding_cache.run(
+                input_ids, position_ids, inference_params.sequence_len_offset
+            ).clone()
+
+    logits_postprocess_fn = (
+        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
+    )
+
     scores = []
     with torch.inference_mode():
         if timing:
@@ -132,8 +150,7 @@ def decode(
             torch.cuda.synchronize()
             start = time.time()
         logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
-        if vocab_size is not None:
-            logits = logits[..., :vocab_size]
+        logits = logits_postprocess_fn(logits)
         scores.append(logits if not cg else logits.clone())
         if teacher_outputs is None or teacher_output_len <= seqlen_og:
             next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
@@ -148,22 +165,10 @@ def decode(
                 dtype=torch.long,
                 device=input_ids.device,
             )
-            if not cg:
-                logits = model(
-                    rearrange(next_token, "b -> b 1"),
-                    position_ids=position_ids,
-                    inference_params=inference_params,
-                    last_token_only=True,
-                ).logits
-            else:
-                logits = model._decoding_cache.run(
-                    rearrange(next_token, "b -> b 1"),
-                    position_ids,
-                    inference_params.sequence_len_offset,
-                )
-            if vocab_size is not None:
-                logits = logits[..., :vocab_size]
-            scores.append(logits if not cg else logits.clone())
+            logits = logits_postprocess_fn(logits_forward_fn(
+                rearrange(next_token, "b -> b 1"), position_ids, inference_params
+            ))
+            scores.append(logits)
             if (
                 teacher_outputs is None
                 or teacher_output_len <= inference_params.sequence_len_offset + 1
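
For context, a minimal self-contained sketch of the pattern this patch applies: the cg / non-cg branch and the optional vocab-size trimming are hoisted into two callables defined once before the decoding loop, so every step runs the same two calls. ToyModel, ToyInferenceParams, and make_step_fns below are hypothetical stand-ins, not the flash-attn API; the real code dispatches to model._decoding_cache.run as shown in the diff.

from dataclasses import dataclass

import torch


@dataclass
class ToyInferenceParams:
    # Stand-in for flash-attn's InferenceParams; only the field this sketch needs.
    sequence_len_offset: int = 0


class ToyModel(torch.nn.Module):
    # Stand-in model returning logits over a padded vocabulary (a plain tensor,
    # not a CausalLMOutput, so no .logits attribute here).
    def __init__(self, padded_vocab_size=1024):
        super().__init__()
        self.padded_vocab_size = padded_vocab_size

    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=True):
        batch, seqlen = input_ids.shape
        return torch.randn(batch, 1 if last_token_only else seqlen, self.padded_vocab_size)


def make_step_fns(model, cg, vocab_size=None, decoding_cache=None):
    # Mirrors the roles of logits_forward_fn / logits_postprocess_fn in the patch.
    def forward_fn(input_ids, position_ids, inference_params):
        if not cg:
            # Eager path: ordinary forward call.
            return model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                last_token_only=True,
            )
        else:
            # CUDA-graph path: replay the captured graph and clone the output
            # buffer, as the patched code does.
            return decoding_cache.run(
                input_ids, position_ids, inference_params.sequence_len_offset
            ).clone()

    postprocess_fn = lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
    return forward_fn, postprocess_fn


if __name__ == "__main__":
    model = ToyModel(padded_vocab_size=1024)
    forward_fn, postprocess_fn = make_step_fns(model, cg=False, vocab_size=1000)
    params = ToyInferenceParams(sequence_len_offset=5)
    next_token = torch.zeros(2, 1, dtype=torch.long)  # (batch, 1), as in the decode loop
    logits = postprocess_fn(forward_fn(next_token, None, params))
    print(logits.shape)  # torch.Size([2, 1, 1000])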