[Gen] Refactor decoding function

Tri Dao 2023-09-04 17:01:38 -07:00
parent 3557e0bb8f
commit 913922cac5
8 changed files with 81 additions and 86 deletions

View File

@@ -84,6 +84,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
        )
+@torch.inference_mode()
def decode(
    input_ids,
    model,
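The hunk above moves inference-mode handling onto the function itself: the `with torch.inference_mode():` block that used to wrap the generation loop (removed further down) becomes a `@torch.inference_mode()` decorator on `decode`. A minimal sketch of the two equivalent forms, using a toy function rather than the real `decode`:

import torch

@torch.inference_mode()
def step_decorated(x):
    # decorator form: the whole call runs with autograd disabled
    return x * 2

def step_context(x):
    # context-manager form: only the wrapped block runs with autograd disabled
    with torch.inference_mode():
        return x * 2

x = torch.ones(3, requires_grad=True)
print(step_decorated(x).requires_grad, step_context(x).requires_grad)  # False False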
@@ -97,7 +98,7 @@ def decode(
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
-    timing=False,
+    enable_timing=False,
):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
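For reference, a hedged sketch of what a `sample(logits, top_k, top_p, temperature)` call with these semantics typically does on the top-k path (top-p filtering omitted); this is an illustration, not the library's exact implementation:

import torch

def sample_topk(logits, top_k=1, temperature=1.0):
    # top_k == 1: greedy decoding, no randomness
    if top_k == 1:
        return logits.argmax(dim=-1)
    logits = logits / max(temperature, 1e-5)
    # top_k == 0 means "don't limit the number of candidates" (pure sampling)
    if top_k > 0:
        top_k = min(top_k, logits.shape[-1])
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

print(sample_topk(torch.randn(2, 100), top_k=5, temperature=0.8))  # one token id per batch row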
@@ -137,73 +138,67 @@ def decode(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )
-    def logits_forward_fn(input_ids, position_ids, inference_params):
-        if not cg:
-            return model(
-                input_ids,
-                position_ids=position_ids,
-                inference_params=inference_params,
-                num_last_tokens=1,
-            ).logits.squeeze(dim=1)
-        else:
-            return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
-            ).clone()
-    logits_postprocess_fn = (
-        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
-    )
-    scores = []
-    with torch.inference_mode():
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            start = time.time()
-        logits = model(
-            input_ids, inference_params=inference_params, num_last_tokens=1
-        ).logits.squeeze(dim=1)
-        logits = logits_postprocess_fn(logits)
-        scores.append(logits if not cg else logits.clone())
-        if teacher_outputs is None or teacher_output_len <= seqlen_og:
-            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-        else:
-            next_token = teacher_outputs[:, seqlen_og]
-        sequences = [next_token]
-        inference_params.sequence_len_offset = seqlen_og
-        while True:
+    def get_logits(input_ids, inference_params):
+        decoding = inference_params.sequence_len_offset > 0
+        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.sequence_len_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
-            logits = logits_postprocess_fn(
-                logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
-            )
-            scores.append(logits)
-            if (
-                teacher_outputs is None
-                or teacher_output_len <= inference_params.sequence_len_offset + 1
-            ):
-                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-            else:
-                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
-            sequences.append(next_token)
-            inference_params.sequence_len_offset += 1
-            if eos_token_id is not None and (next_token == eos_token_id).all():
-                break
-            if inference_params.sequence_len_offset >= max_length - 1:
-                break
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
+        else:
+            position_ids = None
+        if not cg or not decoding:
+            logits = model(
+                input_ids,
+                position_ids=position_ids,
+                inference_params=inference_params,
+                num_last_tokens=1,
+            ).logits.squeeze(dim=1)
+        else:
+            logits = model._decoding_cache.run(
+                input_ids, position_ids, inference_params.sequence_len_offset
+            ).clone()
+        return logits[..., :vocab_size] if vocab_size is not None else logits
+    def sample_tokens(logits, inference_params):
+        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+        else:
+            token = teacher_outputs[:, inference_params.sequence_len_offset]
+        return rearrange(token, "b -> b 1")
+    def should_stop(current_token, inference_params):
+        if inference_params.sequence_len_offset == 0:
+            return False
+        if eos_token_id is not None and (current_token == eos_token_id).all():
+            return True
+        if inference_params.sequence_len_offset >= max_length - 1:
+            return True
+        return False
+    start = torch.cuda.Event(enable_timing=enable_timing)
+    end = torch.cuda.Event(enable_timing=enable_timing)
+    if enable_timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        start.record()
+    scores, sequences = [], [input_ids]
+    while not should_stop(sequences[-1], inference_params):
+        scores.append(get_logits(sequences[-1], inference_params))
+        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        sequences.append(sample_tokens(scores[-1], inference_params))
+    if enable_timing:
+        end.record()
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        torch.cuda.synchronize()
+        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
-        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
+        sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
    )
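The refactor replaces the hand-rolled prefill-then-loop logic with three helpers (get_logits, sample_tokens, should_stop) driven by a single while loop. A self-contained sketch of that loop structure, using a stand-in model that returns random logits and a toy DummyParams in place of the real model and InferenceParams (both names are illustrative only):

import torch
from dataclasses import dataclass

@dataclass
class DummyParams:
    sequence_len_offset: int = 0  # tokens already processed (0 until the prompt step runs)

def toy_logits(input_ids, vocab_size=100):
    # stand-in for get_logits(): one logit vector per batch element for the last position
    return torch.randn(input_ids.shape[0], vocab_size)

def sample_tokens(logits):
    # greedy sampling (top_k=1); returns shape (batch, 1) like the refactored helper
    return logits.argmax(dim=-1, keepdim=True)

def should_stop(current_token, params, max_length, eos_token_id=None):
    if params.sequence_len_offset == 0:
        return False  # always run the prompt (prefill) step
    if eos_token_id is not None and (current_token == eos_token_id).all():
        return True
    return params.sequence_len_offset >= max_length - 1

input_ids = torch.randint(0, 100, (2, 5))  # batch of 2 prompts, 5 tokens each
params = DummyParams()
scores, sequences = [], [input_ids]
while not should_stop(sequences[-1], params, max_length=10):
    scores.append(toy_logits(sequences[-1]))
    params.sequence_len_offset += sequences[-1].shape[1]  # prompt length first, then 1 per step
    sequences.append(sample_tokens(scores[-1]))
print(torch.cat(sequences, dim=1).shape)  # torch.Size([2, 10]): prompt plus generated tokens

The same loop handles both the prompt pass and the per-token steps: the first iteration consumes the whole prompt, and every later iteration consumes the single token sampled just before it.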
@@ -280,7 +275,7 @@ def decode_speculative(
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
-    timing=False,
+    enable_timing=False,
    debug=False,
):
"""
@@ -446,7 +441,7 @@ def decode_speculative(
    sequences = [input_ids]
    scores = []
    with torch.inference_mode():
-        if timing:
+        if enable_timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
@@ -566,7 +561,7 @@ def decode_speculative(
            ).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
-        if timing:
+        if enable_timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
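With the rename, the flag also feeds torch.cuda.Event(enable_timing=...) in decode(), so the reported time comes from CUDA events rather than time.time(). A minimal sketch of that pattern, with a placeholder matmul standing in for prompt processing and decoding (requires a CUDA device):

import torch

assert torch.cuda.is_available()
enable_timing = True
start = torch.cuda.Event(enable_timing=enable_timing)
end = torch.cuda.Event(enable_timing=enable_timing)
if enable_timing:
    start.record()
x = torch.randn(1024, 1024, device="cuda")
y = x @ x  # placeholder GPU work
if enable_timing:
    end.record()
    torch.cuda.synchronize()  # events are recorded asynchronously; sync before reading them
    print(f"elapsed: {start.elapsed_time(end):.0f}ms")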

View File

@@ -289,7 +289,7 @@ def test_baichuan_generation(model_name):
        fused_ft_kernel=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
@@ -310,7 +310,7 @@ def test_baichuan_generation(model_name):
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
@@ -400,7 +400,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    # Capture graph outside the timing loop
@@ -419,7 +419,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

View File

@@ -245,7 +245,7 @@ def test_falcon_generation(model_name):
        fused_ft_kernel=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
@@ -264,7 +264,7 @@ def test_falcon_generation(model_name):
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
@@ -351,7 +351,7 @@ def test_falcon_parallel_generation(model_name, world_size):
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    # Capture graph outside the timing loop
@@ -368,7 +368,7 @@ def test_falcon_parallel_generation(model_name, world_size):
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

View File

@@ -200,7 +200,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
        fused_ft_kernel=fused_ft_kernel,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
@@ -212,7 +212,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    print(out_cg.sequences)
@@ -267,7 +267,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
        teacher_outputs=teacher_outputs,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        **kwargs,
    )
    return torch.stack(out.scores, dim=1)
@@ -431,7 +431,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
        fused_ft_kernel=fused_ft_kernel,
        cg=cg,
        speculative_lookahead=4,
-        timing=True,
+        enable_timing=True,
    )
    print(tokenizer.batch_decode(out.sequences))
    out_og = model.generate(
@@ -440,7 +440,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
        top_k=5,
        fused_ft_kernel=fused_ft_kernel,
        cg=False,
-        timing=True,
+        enable_timing=True,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out_og.sequences))
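test_gpt2_speculative_decoding exercises decode_speculative with speculative_lookahead=4. As a rough illustration of the accept-the-agreeing-prefix idea behind it (not the library's implementation: toy deterministic stand-ins replace the draft and target models, and verification is shown token by token here rather than in one batched forward pass):

def draft_propose(prefix, k):
    # toy draft model: proposes k tokens deterministically from the last token
    tokens, last = [], prefix[-1]
    for _ in range(k):
        last = (last * 31 + 7) % 50
        tokens.append(last)
    return tokens

def target_next_token(prefix):
    # toy target model: agrees with the draft except when the last token is a multiple of 5
    last = prefix[-1]
    return (last * 31 + 7) % 50 if last % 5 else (last + 1) % 50

def speculative_step(prefix, lookahead=4):
    proposal = draft_propose(prefix, lookahead)
    accepted, cur = [], list(prefix)
    for tok in proposal:
        expected = target_next_token(cur)
        if expected != tok:
            accepted.append(expected)  # keep the target's token at the first disagreement
            break
        accepted.append(tok)
        cur.append(tok)
    return prefix + accepted

print(speculative_step([3, 17, 42], lookahead=4))  # several tokens can be accepted per step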

View File

@@ -114,7 +114,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
        fused_ft_kernel=fused_ft_kernel,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    print(out.sequences)
    if fused_ft_kernel:
@@ -127,7 +127,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    print(out_cg.sequences)

View File

@@ -144,7 +144,7 @@ def test_gptj_generation(model_name):
        # eos_token_id=eos_token_id, fused_ft_kernel=False,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
@@ -163,7 +163,7 @@ def test_gptj_generation(model_name):
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()

View File

@@ -295,7 +295,7 @@ def test_llama_generation(model_name, checkpoint_format):
        fused_ft_kernel=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
@@ -314,7 +314,7 @@ def test_llama_generation(model_name, checkpoint_format):
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
@@ -403,7 +403,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    # Capture graph outside the timing loop
@@ -420,7 +420,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

View File

@@ -158,7 +158,7 @@ def test_opt_generation(model_name):
        fused_ft_kernel=fused_ft_kernel,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
@@ -179,7 +179,7 @@ def test_opt_generation(model_name):
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
-        timing=True,
+        enable_timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")