[Bugfix] Bandaid fix for speculative decoding tests (#9327)
This commit is contained in:
parent
f519902c52
commit
16b24e7dcd
@ -17,6 +17,7 @@ import torch.nn as nn
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.abstract import AttentionState
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
@ -1001,6 +1002,17 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
|
||||
dtype=np.int32)
|
||||
|
||||
# Attention-free but stateful models like Mamba need a placeholder attn
|
||||
# backend, as the attention metadata is needed to manage internal state.
|
||||
# However we must bypass attention selection altogether for some models
|
||||
# used for speculative decoding to avoid a divide-by-zero in
|
||||
# model_config.get_head_size()
|
||||
num_attn_heads = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config)
|
||||
needs_attn_backend = (num_attn_heads != 0
|
||||
or self.model_config.is_attention_free)
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.get_sliding_window(),
|
||||
@ -1008,9 +1020,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
)
|
||||
) if needs_attn_backend else None
|
||||
if self.attn_backend:
|
||||
self.attn_state = self.attn_backend.get_state_cls()(
|
||||
weakref.proxy(self))
|
||||
else:
|
||||
self.attn_state = CommonAttentionState(weakref.proxy(self))
|
||||
|
||||
# Multi-modal data support
|
||||
self.input_registry = input_registry
|
||||
|
||||
Loading…
Reference in New Issue
Block a user