From e9630458c7b11732e147c120817c53420280d471 Mon Sep 17 00:00:00 2001
From: Bongwon Jang <152451401+bong-furiosa@users.noreply.github.com>
Date: Tue, 6 Aug 2024 00:05:05 +0900
Subject: [PATCH] [SpecDecode] Support FlashInfer in DraftModelRunner (#6926)

---
 vllm/spec_decode/draft_model_runner.py | 47 ++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index 0b755600..b76a1ab4 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -11,6 +11,17 @@ except ModuleNotFoundError:
     from vllm.attention.backends.rocm_flash_attn import (
         ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+except ImportError:
+    BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
+
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
@@ -79,6 +90,11 @@ class TP1DraftModelRunner(ModelRunner):
             return_hidden_states=return_hidden_states,
         )
 
+        self.flashinfer_decode_workspace_buffer = None
+        self.flashinfer_decode_wrapper = None
+        self.flashinfer_prefill_workspace_buffer = None
+        self.flashinfer_prefill_wrapper = None
+
     def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                     num_queries):
         assert isinstance(attn_metadata, FlashAttentionMetadata)
@@ -286,6 +302,37 @@ class TP1DraftModelRunner(ModelRunner):
                 model_input.prompt_adapter_requests,
                 model_input.prompt_adapter_mapping)
 
+        if self.attn_backend.get_name() == "flashinfer":
+            assert model_input.attn_metadata is not None
+            assert model_input.input_tokens is not None
+            if self.flashinfer_decode_workspace_buffer is None:
+                self.flashinfer_decode_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_decode_wrapper = \
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_decode_workspace_buffer, "NHD")
+                self.flashinfer_prefill_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_prefill_wrapper = \
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_prefill_workspace_buffer, "NHD")
+
+            model_input.attn_metadata.prefill_wrapper = \
+                self.flashinfer_prefill_wrapper
+            if model_input.attn_metadata.use_cuda_graph:
+                batch_size = model_input.input_tokens.shape[0]
+                model_input.attn_metadata.decode_wrapper = \
+                    self.graph_runners[model_input.
+                        virtual_engine][batch_size].flashinfer_decode_wrapper
+            else:
+                model_input.attn_metadata.decode_wrapper = \
+                    self.flashinfer_decode_wrapper
+            model_input.attn_metadata.begin_forward()
+
         # Detect exec mode
         assert model_input.attn_metadata is not None
         use_cuda_graph = False