From 5d80a9178b48f211a6fa02a7b5ddc0a0ae29aa44 Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Thu, 18 Jan 2024 09:40:34 -0800 Subject: [PATCH] Minor fix in prefill cache example (#2494) --- examples/offline_inference_with_prefix.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py index df9f1364..8ccfb1ce 100644 --- a/examples/offline_inference_with_prefix.py +++ b/examples/offline_inference_with_prefix.py @@ -40,8 +40,16 @@ print("-" * 80) # -1 since the last token can change when concatenating prompts. prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 -# Generate with prefix -outputs = llm.generate(generating_prompts, sampling_params, +# The llm.generate call will batch all prompts and send the batch at once if resources allow. +# The prefix will only be cached after the first batch is processed, so we need to call generate once +# to calculate the prefix and cache it. +outputs = llm.generate(generating_prompts[0], + sampling_params, + prefix_pos=[prefix_pos]) + +# Subsequent batches can leverage the cached prefix +outputs = llm.generate(generating_prompts, + sampling_params, prefix_pos=[prefix_pos] * len(generating_prompts)) # Print the outputs. You should see the same outputs as before