From 2110557dabe8a18b811116c1ae9fdf75fbe27df6 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Wed, 26 Jun 2024 21:12:10 -0700
Subject: [PATCH] [BugFix] Fix cuda graph for MLPSpeculator (#5875)

Co-authored-by: Abhinav Goyal
---
 examples/offline_inference_mlpspeculator.py | 1 -
 vllm/worker/model_runner.py                 | 9 ++++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/offline_inference_mlpspeculator.py b/examples/offline_inference_mlpspeculator.py
index 5448ec1f..5dec4a76 100644
--- a/examples/offline_inference_mlpspeculator.py
+++ b/examples/offline_inference_mlpspeculator.py
@@ -52,7 +52,6 @@ if __name__ == "__main__":
         speculative_model="ibm-fms/llama-13b-accelerator",
         # These are currently required for MLPSpeculator decoding
         use_v2_block_manager=True,
-        enforce_eager=True,
     )
 
     print("With speculation")
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 9fdb2ea5..ac820bbc 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1020,10 +1020,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
+            assert model_input.sampling_metadata is not None
+            indices = model_input.sampling_metadata.selected_token_indices
             if model_input.is_prompt:
-                assert model_input.sampling_metadata is not None
-                hidden_states = hidden_states.index_select(
-                    0, model_input.sampling_metadata.selected_token_indices)
+                hidden_states = hidden_states.index_select(0, indices)
+            elif decode_meta.use_cuda_graph:
+                hidden_states = hidden_states[:len(indices)]
+
             output.hidden_states = hidden_states
 
         return output
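
Note (not part of the patch): a minimal, standalone sketch of the behavior the
model_runner.py hunk fixes. When decode runs inside a captured CUDA graph, the
batch is padded up to the captured graph size, so the hidden-states tensor that
comes back has padded rows; before this patch those padded rows were passed to
the MLPSpeculator unchanged, which is why the example previously had to set
enforce_eager=True. The helper name, shapes, and flags below are illustrative
assumptions, not vLLM's actual API.

    import torch

    def select_last_token_hidden_states(hidden_states, selected_token_indices,
                                        is_prompt, use_cuda_graph):
        # Hypothetical helper mirroring the patched logic in ModelRunner.
        if is_prompt:
            # Prefill: one row per prompt token; gather each sequence's
            # last-token row via the sampler's selected indices.
            return hidden_states.index_select(0, selected_token_indices)
        if use_cuda_graph:
            # Decode under CUDA graphs: the batch was padded to the captured
            # graph size, so only the first len(selected_token_indices) rows
            # belong to real sequences.
            return hidden_states[:len(selected_token_indices)]
        # Eager decode: already exactly one row per real sequence.
        return hidden_states

    # Example: 3 real sequences, decode graph captured at padded batch size 8.
    padded = torch.randn(8, 4096)
    indices = torch.arange(3)
    out = select_last_token_hidden_states(padded, indices,
                                          is_prompt=False, use_cuda_graph=True)
    assert out.shape == (3, 4096)

Trimming the padded rows before assigning output.hidden_states is what lets the
example drop enforce_eager=True and run MLPSpeculator decoding with CUDA graphs
enabled.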