[Gen] Refactor decode function a bit
This commit is contained in:
parent
371e20658c
commit
847abe653c
@ -124,6 +124,24 @@ def decode(
|
||||
inference_params = InferenceParams(
|
||||
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
|
||||
)
|
||||
|
||||
def logits_forward_fn(input_ids, position_ids, inference_params):
    """Run a single forward step and return its logits.

    Reads ``cg`` and ``model`` from the enclosing scope.

    * When ``cg`` is truthy, the step is delegated to
      ``model._decoding_cache.run`` with ``input_ids``, ``position_ids`` and
      ``inference_params.sequence_len_offset``; a ``.clone()`` of the result
      is returned.
    * Otherwise ``model`` is called directly with ``last_token_only=True`` and
      the ``.logits`` attribute of its output is returned.
    """
    if cg:
        # NOTE(review): presumably the decoding cache reuses its output buffer
        # between runs, which is why the result is copied -- confirm.
        cached_out = model._decoding_cache.run(
            input_ids, position_ids, inference_params.sequence_len_offset
        )
        return cached_out.clone()

    model_out = model(
        input_ids,
        position_ids=position_ids,
        inference_params=inference_params,
        last_token_only=True,
    )
    return model_out.logits
|
||||
|
||||
def logits_postprocess_fn(logits):
    """Trim the last dimension of ``logits`` to the first ``vocab_size`` entries.

    ``vocab_size`` is read from the enclosing scope at call time. When it is
    ``None`` the input is returned unchanged.
    """
    if vocab_size is None:
        return logits
    return logits[..., :vocab_size]
|
||||
|
||||
scores = []
|
||||
with torch.inference_mode():
|
||||
if timing:
|
||||
@ -132,8 +150,7 @@ def decode(
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
|
||||
if vocab_size is not None:
|
||||
logits = logits[..., :vocab_size]
|
||||
logits = logits_postprocess_fn(logits)
|
||||
scores.append(logits if not cg else logits.clone())
|
||||
if teacher_outputs is None or teacher_output_len <= seqlen_og:
|
||||
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
|
||||
@ -148,22 +165,10 @@ def decode(
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
if not cg:
|
||||
logits = model(
|
||||
rearrange(next_token, "b -> b 1"),
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
last_token_only=True,
|
||||
).logits
|
||||
else:
|
||||
logits = model._decoding_cache.run(
|
||||
rearrange(next_token, "b -> b 1"),
|
||||
position_ids,
|
||||
inference_params.sequence_len_offset,
|
||||
)
|
||||
if vocab_size is not None:
|
||||
logits = logits[..., :vocab_size]
|
||||
scores.append(logits if not cg else logits.clone())
|
||||
logits = logits_postprocess_fn(logits_forward_fn(
|
||||
rearrange(next_token, "b -> b 1"), position_ids, inference_params
|
||||
))
|
||||
scores.append(logits)
|
||||
if (
|
||||
teacher_outputs is None
|
||||
or teacher_output_len <= inference_params.sequence_len_offset + 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user