[Gen] Remove commented code
This commit is contained in:
parent
b48599002a
commit
f95c2fc108
@ -97,7 +97,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
|||||||
if cg:
|
if cg:
|
||||||
assert fused_ft_kernel
|
assert fused_ft_kernel
|
||||||
run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
|
run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
|
||||||
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
|
|
||||||
if timing:
|
if timing:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
while True:
|
while True:
|
||||||
@ -117,8 +116,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
|||||||
break
|
break
|
||||||
if timing:
|
if timing:
|
||||||
print(f'Decoding time: {time.time() - start}')
|
print(f'Decoding time: {time.time() - start}')
|
||||||
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))
|
|
||||||
# prof.export_chrome_trace("gpt2s_generation.json")
|
|
||||||
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
||||||
return output_cls(
|
return output_cls(
|
||||||
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
|
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user