[Gen] Refactor decode function a bit

This commit is contained in:
Tri Dao 2023-08-26 14:36:44 -07:00
parent 371e20658c
commit 847abe653c

View File

@ -124,6 +124,24 @@ def decode(
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
)
def logits_forward_fn(input_ids, position_ids, inference_params):
if not cg:
return model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
last_token_only=True,
).logits
else:
return model._decoding_cache.run(
input_ids, position_ids, inference_params.sequence_len_offset
).clone()
logits_postprocess_fn = (
lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
)
scores = []
with torch.inference_mode():
if timing:
@ -132,8 +150,7 @@ def decode(
torch.cuda.synchronize()
start = time.time()
logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
if vocab_size is not None:
logits = logits[..., :vocab_size]
logits = logits_postprocess_fn(logits)
scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= seqlen_og:
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
@ -148,22 +165,10 @@ def decode(
dtype=torch.long,
device=input_ids.device,
)
if not cg:
logits = model(
rearrange(next_token, "b -> b 1"),
position_ids=position_ids,
inference_params=inference_params,
last_token_only=True,
).logits
else:
logits = model._decoding_cache.run(
rearrange(next_token, "b -> b 1"),
position_ids,
inference_params.sequence_len_offset,
)
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits if not cg else logits.clone())
logits = logits_postprocess_fn(logits_forward_fn(
rearrange(next_token, "b -> b 1"), position_ids, inference_params
))
scores.append(logits)
if (
teacher_outputs is None
or teacher_output_len <= inference_params.sequence_len_offset + 1