[Gen] Refactor decoding function
parent 3557e0bb8f
commit 913922cac5
@@ -84,6 +84,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
         )


+@torch.inference_mode()
 def decode(
     input_ids,
     model,
@@ -97,7 +98,7 @@ def decode(
     tensor_parallel=1,
     fused_ft_kernel=False,
     cg=False,
-    timing=False,
+    enable_timing=False,
 ):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
@@ -137,73 +138,67 @@ def decode(
         max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
     )

-    def logits_forward_fn(input_ids, position_ids, inference_params):
-        if not cg:
-            return model(
-                input_ids,
-                position_ids=position_ids,
-                inference_params=inference_params,
-                num_last_tokens=1,
-            ).logits.squeeze(dim=1)
-        else:
-            return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
-            ).clone()
-
-    logits_postprocess_fn = (
-        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
-    )
-
-    scores = []
-    with torch.inference_mode():
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            start = time.time()
-        logits = model(
-            input_ids, inference_params=inference_params, num_last_tokens=1
-        ).logits.squeeze(dim=1)
-        logits = logits_postprocess_fn(logits)
-        scores.append(logits if not cg else logits.clone())
-        if teacher_outputs is None or teacher_output_len <= seqlen_og:
-            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-        else:
-            next_token = teacher_outputs[:, seqlen_og]
-        sequences = [next_token]
-        inference_params.sequence_len_offset = seqlen_og
-        while True:
+    def get_logits(input_ids, inference_params):
+        decoding = inference_params.sequence_len_offset > 0
+        if decoding:
             position_ids = torch.full(
                 (batch_size, 1),
                 inference_params.sequence_len_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )
-            logits = logits_postprocess_fn(
-                logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
-            )
-            scores.append(logits)
-            if (
-                teacher_outputs is None
-                or teacher_output_len <= inference_params.sequence_len_offset + 1
-            ):
-                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-            else:
-                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
-            sequences.append(next_token)
-            inference_params.sequence_len_offset += 1
-            if eos_token_id is not None and (next_token == eos_token_id).all():
-                break
-            if inference_params.sequence_len_offset >= max_length - 1:
-                break
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
+        else:
+            position_ids = None
+        if not cg or not decoding:
+            logits = model(
+                input_ids,
+                position_ids=position_ids,
+                inference_params=inference_params,
+                num_last_tokens=1,
+            ).logits.squeeze(dim=1)
+        else:
+            logits = model._decoding_cache.run(
+                input_ids, position_ids, inference_params.sequence_len_offset
+            ).clone()
+        return logits[..., :vocab_size] if vocab_size is not None else logits
+
+    def sample_tokens(logits, inference_params):
+        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+        else:
+            token = teacher_outputs[:, inference_params.sequence_len_offset]
+        return rearrange(token, "b -> b 1")
+
+    def should_stop(current_token, inference_params):
+        if inference_params.sequence_len_offset == 0:
+            return False
+        if eos_token_id is not None and (current_token == eos_token_id).all():
+            return True
+        if inference_params.sequence_len_offset >= max_length - 1:
+            return True
+        return False
+
+    start = torch.cuda.Event(enable_timing=enable_timing)
+    end = torch.cuda.Event(enable_timing=enable_timing)
+
+    if enable_timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        start.record()
+    scores, sequences = [], [input_ids]
+    while not should_stop(sequences[-1], inference_params):
+        scores.append(get_logits(sequences[-1], inference_params))
+        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        sequences.append(sample_tokens(scores[-1], inference_params))
+    if enable_timing:
+        end.record()
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        torch.cuda.synchronize()
+        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
-        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
+        sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
     )
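The timing change above drops the `torch.cuda.synchronize()` + `time.time()` measurement in favor of CUDA events, gated by the renamed `enable_timing` flag. A minimal sketch of that pattern in isolation (the `time_gpu` helper name is illustrative and not from the repo; it assumes a CUDA device is available):

import torch

def time_gpu(fn, *args, **kwargs):
    # Record CUDA events around the GPU work instead of reading time.time(),
    # so we measure GPU execution without forcing a sync before launch.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = fn(*args, **kwargs)
    end.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()
    print(f"elapsed: {start.elapsed_time(end):.0f}ms")  # milliseconds
    return out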
@@ -280,7 +275,7 @@ def decode_speculative(
     tensor_parallel=1,
     fused_ft_kernel=False,
     cg=False,
-    timing=False,
+    enable_timing=False,
     debug=False,
 ):
     """
@@ -446,7 +441,7 @@ def decode_speculative(
     sequences = [input_ids]
     scores = []
     with torch.inference_mode():
-        if timing:
+        if enable_timing:
             if tensor_parallel > 1:
                 torch.distributed.barrier()
             torch.cuda.synchronize()
@@ -566,7 +561,7 @@ def decode_speculative(
             ).logits
             print((scores[-1] - scores_ref[:, :-1]).abs().max())

-        if timing:
+        if enable_timing:
             if tensor_parallel > 1:
                 torch.distributed.barrier()
             torch.cuda.synchronize()
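In the refactored `decode` loop further above, `sequences` now starts from `input_ids` and each step appends a `(batch, 1)` token tensor, so the output is assembled with `torch.cat(sequences, dim=1)` rather than `torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)`. A small standalone sketch (toy tensors, not the real model outputs) showing the two assemblies produce the same result:

import torch

batch_size, prompt_len = 2, 4
input_ids = torch.randint(0, 100, (batch_size, prompt_len))
new_tokens = [torch.randint(0, 100, (batch_size,)) for _ in range(3)]

# Old assembly: a list of (batch,) tokens, stacked then concatenated to the prompt.
old = torch.cat([input_ids, torch.stack(new_tokens, dim=1)], dim=1)

# New assembly: the list holds (batch, k) tensors and begins with the prompt itself.
sequences = [input_ids] + [tok.unsqueeze(1) for tok in new_tokens]
new = torch.cat(sequences, dim=1)

assert torch.equal(old, new)  # same (batch, prompt_len + 3) tensor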
@@ -289,7 +289,7 @@ def test_baichuan_generation(model_name):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -310,7 +310,7 @@ def test_baichuan_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -400,7 +400,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )

     # Capture graph outside the timing loop
@@ -419,7 +419,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
@@ -245,7 +245,7 @@ def test_falcon_generation(model_name):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -264,7 +264,7 @@ def test_falcon_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -351,7 +351,7 @@ def test_falcon_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )

     # Capture graph outside the timing loop
@@ -368,7 +368,7 @@ def test_falcon_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
@@ -200,7 +200,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
@@ -212,7 +212,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out_cg.sequences)

@@ -267,7 +267,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
         teacher_outputs=teacher_outputs,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         **kwargs,
     )
     return torch.stack(out.scores, dim=1)
@@ -431,7 +431,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         fused_ft_kernel=fused_ft_kernel,
         cg=cg,
         speculative_lookahead=4,
-        timing=True,
+        enable_timing=True,
     )
     print(tokenizer.batch_decode(out.sequences))
     out_og = model.generate(
@@ -440,7 +440,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         top_k=5,
         fused_ft_kernel=fused_ft_kernel,
         cg=False,
-        timing=True,
+        enable_timing=True,
         return_dict_in_generate=True,
     )
     print(tokenizer.batch_decode(out_og.sequences))
@@ -114,7 +114,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out.sequences)
     if fused_ft_kernel:
@@ -127,7 +127,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
-            timing=True,
+            enable_timing=True,
         )
         print(out_cg.sequences)

@@ -144,7 +144,7 @@ def test_gptj_generation(model_name):
         # eos_token_id=eos_token_id, fused_ft_kernel=False,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -163,7 +163,7 @@ def test_gptj_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -295,7 +295,7 @@ def test_llama_generation(model_name, checkpoint_format):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -314,7 +314,7 @@ def test_llama_generation(model_name, checkpoint_format):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -403,7 +403,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )

     # Capture graph outside the timing loop
@@ -420,7 +420,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
@@ -158,7 +158,7 @@ def test_opt_generation(model_name):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     torch.cuda.synchronize()
     print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
@@ -179,7 +179,7 @@ def test_opt_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     torch.cuda.synchronize()
     print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")