[SpecDecode] Support FlashInfer in DraftModelRunner (#6926)
commit e9630458c7 (parent 82a1b1a82b)
@@ -11,6 +11,17 @@ except ModuleNotFoundError:
     from vllm.attention.backends.rocm_flash_attn import (
         ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+except ImportError:
+    BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
+
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
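
The FlashInfer imports are wrapped in try/except so that vLLM still imports on machines without the package: the wrapper classes fall back to None and the workspace size to 0. A minimal sketch of the same pattern follows, with a hypothetical require_flashinfer() helper (not part of this commit) that fails fast if the backend is requested without the package installed:

# Hedged sketch, not part of the commit: the guarded-import pattern above,
# plus a hypothetical fail-fast helper.
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024  # 256 MiB workspace
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0


def require_flashinfer() -> None:
    # Hypothetical helper (not a vLLM function): raise a clear error if the
    # flashinfer backend is selected but the package is missing.
    if BatchDecodeWithPagedKVCacheWrapper is None:
        raise ImportError(
            "The flashinfer attention backend was requested, but the "
            "flashinfer package is not installed.")
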
@@ -79,6 +90,11 @@ class TP1DraftModelRunner(ModelRunner):
             return_hidden_states=return_hidden_states,
         )
 
+        self.flashinfer_decode_workspace_buffer = None
+        self.flashinfer_decode_wrapper = None
+        self.flashinfer_prefill_workspace_buffer = None
+        self.flashinfer_prefill_wrapper = None
+
     def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                     num_queries):
         assert isinstance(attn_metadata, FlashAttentionMetadata)
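
The four new attributes keep wrapper creation lazy: nothing is allocated in __init__, and the 256 MiB workspace buffers only appear once the flashinfer backend actually runs, so draft runners on other backends pay nothing. An illustrative sketch of that idea (not the committed code; assumes the guarded FlashInfer imports from the first hunk are in scope):

# Hedged sketch of lazy workspace allocation; class name is illustrative only.
import torch


class _LazyFlashInferWrappers:
    def __init__(self, device):
        self.device = device
        self.decode_workspace_buffer = None  # allocated on first use
        self.decode_wrapper = None

    def get_decode_wrapper(self):
        # Allocate the workspace and build the wrapper only once, on demand.
        if self.decode_workspace_buffer is None:
            self.decode_workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.decode_workspace_buffer, "NHD")
        return self.decode_wrapper
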
@@ -286,6 +302,37 @@ class TP1DraftModelRunner(ModelRunner):
                     model_input.prompt_adapter_requests,
                     model_input.prompt_adapter_mapping)
 
+        if self.attn_backend.get_name() == "flashinfer":
+            assert model_input.attn_metadata is not None
+            assert model_input.input_tokens is not None
+            if self.flashinfer_decode_workspace_buffer is None:
+                self.flashinfer_decode_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_decode_wrapper = \
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_decode_workspace_buffer, "NHD")
+                self.flashinfer_prefill_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_prefill_wrapper = \
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_prefill_workspace_buffer, "NHD")
+
+            model_input.attn_metadata.prefill_wrapper = \
+                self.flashinfer_prefill_wrapper
+            if model_input.attn_metadata.use_cuda_graph:
+                batch_size = model_input.input_tokens.shape[0]
+                model_input.attn_metadata.decode_wrapper = \
+                    self.graph_runners[model_input.
+                    virtual_engine][batch_size].flashinfer_decode_wrapper
+            else:
+                model_input.attn_metadata.decode_wrapper = \
+                    self.flashinfer_decode_wrapper
+            model_input.attn_metadata.begin_forward()
+
         # Detect exec mode
         assert model_input.attn_metadata is not None
         use_cuda_graph = False
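
The block above lazily allocates the decode and prefill workspace buffers, attaches the prefill wrapper to the attention metadata, picks the decode wrapper, and finally calls begin_forward(). The wrapper choice is the subtle part: under CUDA graphs, the wrapper captured by the graph runner for this exact padded batch size has to be reused, while eager mode uses the runner's own wrapper. A hedged sketch of just that selection (illustrative, not the committed code):

# Sketch of the decode-wrapper selection above; `runner` stands in for the
# TP1DraftModelRunner instance and `model_input` for its prepared inputs.
def select_decode_wrapper(runner, model_input):
    attn_metadata = model_input.attn_metadata
    if attn_metadata.use_cuda_graph:
        # One captured graph (and one decode wrapper) per padded batch size.
        batch_size = model_input.input_tokens.shape[0]
        graph_runner = runner.graph_runners[
            model_input.virtual_engine][batch_size]
        return graph_runner.flashinfer_decode_wrapper
    # Eager mode: use the runner's lazily allocated wrapper.
    return runner.flashinfer_decode_wrapper

This path is only taken when the FlashInfer backend is active, which is what the attn_backend.get_name() == "flashinfer" check guards; in vLLM the backend is typically selected via the VLLM_ATTENTION_BACKEND=FLASHINFER environment variable.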