[SpecDecode] Support FlashInfer in DraftModelRunner (#6926)
commit e9630458c7
parent 82a1b1a82b
@@ -11,6 +11,17 @@ except ModuleNotFoundError:
     from vllm.attention.backends.rocm_flash_attn import (
         ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+except ImportError:
+    BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
+
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
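Note: the hunk above treats FlashInfer as an optional dependency, so the wrapper classes and the workspace size fall back to None / 0 when the package is not installed. A minimal standalone sketch of that pattern; the flashinfer_available helper is hypothetical and not part of this commit:

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024  # 256 MiB scratch space
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0


def flashinfer_available() -> bool:
    # Hypothetical helper: True only when the optional flashinfer package imported.
    return BatchDecodeWithPagedKVCacheWrapper is not None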
@@ -79,6 +90,11 @@ class TP1DraftModelRunner(ModelRunner):
             return_hidden_states=return_hidden_states,
         )
 
+        self.flashinfer_decode_workspace_buffer = None
+        self.flashinfer_decode_wrapper = None
+        self.flashinfer_prefill_workspace_buffer = None
+        self.flashinfer_prefill_wrapper = None
+
     def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                     num_queries):
         assert isinstance(attn_metadata, FlashAttentionMetadata)
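Note: these four attributes start as None and are filled in lazily on the first FlashInfer-backed call to execute_model (next hunk). A simplified, self-contained sketch of that lazy-allocation idiom, with a hypothetical _Runner class standing in for TP1DraftModelRunner:

import torch


class _Runner:
    # Hypothetical stand-in for TP1DraftModelRunner; only the lazy-allocation
    # idiom enabled by the None attributes above is shown.
    def __init__(self) -> None:
        self.flashinfer_decode_workspace_buffer = None  # allocated on first use

    def _ensure_decode_workspace(self, size: int,
                                 device: str = "cpu") -> torch.Tensor:
        # Allocate the byte workspace once, then reuse it on every later call.
        if self.flashinfer_decode_workspace_buffer is None:
            self.flashinfer_decode_workspace_buffer = torch.empty(
                size, dtype=torch.uint8, device=device)
        return self.flashinfer_decode_workspace_buffer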
@@ -286,6 +302,37 @@ class TP1DraftModelRunner(ModelRunner):
                 model_input.prompt_adapter_requests,
                 model_input.prompt_adapter_mapping)
 
+        if self.attn_backend.get_name() == "flashinfer":
+            assert model_input.attn_metadata is not None
+            assert model_input.input_tokens is not None
+            if self.flashinfer_decode_workspace_buffer is None:
+                self.flashinfer_decode_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_decode_wrapper = \
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_decode_workspace_buffer, "NHD")
+                self.flashinfer_prefill_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_prefill_wrapper = \
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_prefill_workspace_buffer, "NHD")
+
+            model_input.attn_metadata.prefill_wrapper = \
+                self.flashinfer_prefill_wrapper
+            if model_input.attn_metadata.use_cuda_graph:
+                batch_size = model_input.input_tokens.shape[0]
+                model_input.attn_metadata.decode_wrapper = \
+                    self.graph_runners[model_input.
+                    virtual_engine][batch_size].flashinfer_decode_wrapper
+            else:
+                model_input.attn_metadata.decode_wrapper = \
+                    self.flashinfer_decode_wrapper
+            model_input.attn_metadata.begin_forward()
+
         # Detect exec mode
         assert model_input.attn_metadata is not None
         use_cuda_graph = False
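Note: reduced to its essentials, the hunk above allocates the two workspace buffers once, always attaches the runner's prefill wrapper, picks the decode wrapper from the captured CUDA-graph runner when CUDA graphs are in use, and calls begin_forward() before the draft model runs. A hedged sketch of that wiring; the _Meta dataclass and attach_wrappers function are hypothetical stand-ins, not the real FlashInfer metadata class:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _Meta:
    # Hypothetical stand-in for the FlashInfer attention metadata used above.
    use_cuda_graph: bool
    prefill_wrapper: Optional[Any] = None
    decode_wrapper: Optional[Any] = None

    def begin_forward(self) -> None:
        # The real metadata plans the FlashInfer kernels here; this stub only
        # checks that a decode wrapper was attached first.
        assert self.decode_wrapper is not None


def attach_wrappers(meta: _Meta, prefill_wrapper: Any, eager_decode_wrapper: Any,
                    graph_decode_wrapper: Any) -> None:
    # Mirror the diff: the prefill wrapper always comes from the runner, while
    # the decode wrapper comes from the CUDA-graph runner when graphs are used.
    meta.prefill_wrapper = prefill_wrapper
    meta.decode_wrapper = (graph_decode_wrapper
                           if meta.use_cuda_graph else eager_decode_wrapper)
    meta.begin_forward()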