[Gen] Remove commented code

This commit is contained in:
Tri Dao 2023-01-07 19:06:39 -08:00
parent b48599002a
commit f95c2fc108

View File

@ -97,7 +97,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if cg: if cg:
assert fused_ft_kernel assert fused_ft_kernel
run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length) run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
if timing: if timing:
start = time.time() start = time.time()
while True: while True:
@ -117,8 +116,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
break break
if timing: if timing:
print(f'Decoding time: {time.time() - start}') print(f'Decoding time: {time.time() - start}')
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))
# prof.export_chrome_trace("gpt2s_generation.json")
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls( return output_cls(
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),