[Model] Support SDPA attention for Molmo vision backbone (#9410)
This commit is contained in:
parent
59230ef32b
commit
cf1d62a644
@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from array import array
|
from array import array
|
||||||
@ -14,10 +13,8 @@ from torch import nn
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.attention.selector import (_Backend, backend_name_to_enum,
|
from vllm.attention.selector import _Backend
|
||||||
get_global_forced_attn_backend)
|
|
||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@ -43,12 +40,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||||
from vllm.model_executor.models.utils import make_layers
|
from vllm.model_executor.models.utils import make_layers
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||||
SequenceData)
|
SequenceData)
|
||||||
from vllm.transformers_utils.processor import get_processor
|
from vllm.transformers_utils.processor import get_processor
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
from .utils import get_vit_attn_backend
|
||||||
|
|
||||||
# TODO: hard-coded for now. Consider making it configurable.
|
# TODO: hard-coded for now. Consider making it configurable.
|
||||||
VIT_LAYERS = [-2, -9]
|
VIT_LAYERS = [-2, -9]
|
||||||
@ -190,35 +186,12 @@ class MultiHeadDotProductAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
|
self.attn_backend: _Backend = get_vit_attn_backend()
|
||||||
if selected_backend is None:
|
if self.attn_backend not in {
|
||||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||||
if backend_by_env_var is not None:
|
}:
|
||||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
|
||||||
if selected_backend is None:
|
|
||||||
# For Volta and Turing GPUs, use xformers instead.
|
|
||||||
device_available = current_platform.get_device_capability()[0] >= 8
|
|
||||||
if device_available:
|
|
||||||
from transformers.utils import is_flash_attn_2_available
|
|
||||||
if is_flash_attn_2_available():
|
|
||||||
self._use_flash_attn = True
|
|
||||||
else:
|
|
||||||
log.warning(
|
|
||||||
"Current Molmo implementation has a bug with "
|
|
||||||
"`vllm-flash-attn` inside vision module, so we use "
|
|
||||||
"xformers backend instead. You can run `pip install "
|
|
||||||
"flash-attn to use flash-attention backend.")
|
|
||||||
self._use_flash_attn = False
|
|
||||||
else:
|
|
||||||
self._use_flash_attn = False
|
|
||||||
else:
|
|
||||||
if selected_backend == _Backend.FLASH_ATTN:
|
|
||||||
self._use_flash_attn = True
|
|
||||||
elif selected_backend == _Backend.XFORMERS:
|
|
||||||
self._use_flash_attn = False
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Molmo does not support {selected_backend} backend now.")
|
f"Molmo does not support {self.attn_backend} backend now.")
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
inputs_q: torch.Tensor,
|
inputs_q: torch.Tensor,
|
||||||
@ -240,10 +213,15 @@ class MultiHeadDotProductAttention(nn.Module):
|
|||||||
xk = xk.view(*kv_shape)
|
xk = xk.view(*kv_shape)
|
||||||
xv = xv.view(*kv_shape)
|
xv = xv.view(*kv_shape)
|
||||||
|
|
||||||
if self._use_flash_attn:
|
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||||
from flash_attn import flash_attn_func
|
from flash_attn import flash_attn_func
|
||||||
output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
|
output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
|
||||||
else:
|
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||||
|
xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
|
||||||
|
for x in (xq, xk, xv))
|
||||||
|
output = F.scaled_dot_product_attention(xq, xk, xv)
|
||||||
|
output = rearrange(output, "b h s d -> b s h d ")
|
||||||
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
from xformers import ops as xops
|
from xformers import ops as xops
|
||||||
output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
|
output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
|
||||||
|
|
||||||
|
|||||||
@ -39,10 +39,8 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
|||||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
||||||
make_batched_images, make_batched_videos, smart_resize)
|
make_batched_images, make_batched_videos, smart_resize)
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.attention.selector import (_Backend, backend_name_to_enum,
|
from vllm.attention.selector import _Backend
|
||||||
get_global_forced_attn_backend)
|
|
||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.distributed import get_pp_group, parallel_state
|
from vllm.distributed import get_pp_group, parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
@ -63,14 +61,13 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
|||||||
MultiModalInputs)
|
MultiModalInputs)
|
||||||
from vllm.multimodal.base import MultiModalData
|
from vllm.multimodal.base import MultiModalData
|
||||||
from vllm.multimodal.image import cached_get_image_processor
|
from vllm.multimodal.image import cached_get_image_processor
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import IntermediateTensors, SequenceData
|
from vllm.sequence import IntermediateTensors, SequenceData
|
||||||
from vllm.transformers_utils.config import uses_mrope
|
from vllm.transformers_utils.config import uses_mrope
|
||||||
from vllm.transformers_utils.processor import get_processor
|
from vllm.transformers_utils.processor import get_processor
|
||||||
from vllm.utils import is_cpu
|
|
||||||
|
|
||||||
from .interfaces import SupportsMultiModal, SupportsPP
|
from .interfaces import SupportsMultiModal, SupportsPP
|
||||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
from .utils import (PPMissingLayer, get_vit_attn_backend,
|
||||||
|
is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory)
|
make_empty_intermediate_tensors_factory)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -215,37 +212,12 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
|
self.attn_backend: _Backend = get_vit_attn_backend()
|
||||||
if selected_backend is None:
|
if self.attn_backend not in {
|
||||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||||
if backend_by_env_var is not None:
|
}:
|
||||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
|
||||||
if selected_backend is None:
|
|
||||||
# For Volta and Turing GPUs, use xformers instead.
|
|
||||||
device_available = current_platform.has_device_capability(80)
|
|
||||||
if device_available:
|
|
||||||
from transformers.utils import is_flash_attn_2_available
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
|
||||||
self._use_flash_attn = True
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Current Qwen2-VL implementation has a bug with "
|
|
||||||
"`vllm-flash-attn` inside vision module, so we use "
|
|
||||||
"xformers backend instead. You can run `pip install "
|
|
||||||
"flash-attn to use flash-attention backend.")
|
|
||||||
self._use_flash_attn = False
|
|
||||||
else:
|
|
||||||
self._use_flash_attn = False
|
|
||||||
else:
|
|
||||||
if selected_backend == _Backend.FLASH_ATTN:
|
|
||||||
self._use_flash_attn = True
|
|
||||||
elif selected_backend == _Backend.XFORMERS:
|
|
||||||
self._use_flash_attn = False
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Qwen2-VL does not support {selected_backend} backend now."
|
f"Qwen2-VL does not support {self.attn_backend} backend now.")
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -274,7 +246,7 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
||||||
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
||||||
|
|
||||||
if self._use_flash_attn:
|
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||||
# from vllm_flash_attn.flash_attn_interface import (
|
# from vllm_flash_attn.flash_attn_interface import (
|
||||||
# flash_attn_varlen_func)
|
# flash_attn_varlen_func)
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func
|
||||||
@ -295,7 +267,7 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
context_layer = rearrange(output,
|
context_layer = rearrange(output,
|
||||||
"(b s) ... -> b s ...",
|
"(b s) ... -> b s ...",
|
||||||
b=batch_size)
|
b=batch_size)
|
||||||
elif is_cpu():
|
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||||
seq_length = q.size(1)
|
seq_length = q.size(1)
|
||||||
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
|
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
|
||||||
attention_mask = torch.zeros([1, seq_length, seq_length],
|
attention_mask = torch.zeros([1, seq_length, seq_length],
|
||||||
@ -310,7 +282,7 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
dropout_p=0.0)
|
dropout_p=0.0)
|
||||||
context_layer = rearrange(output, "b h s d -> b s h d ")
|
context_layer = rearrange(output, "b h s d -> b s h d ")
|
||||||
else:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
from xformers import ops as xops
|
from xformers import ops as xops
|
||||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||||
|
|
||||||
|
|||||||
@ -8,15 +8,22 @@ import torch.nn as nn
|
|||||||
from torch.func import functional_call
|
from torch.func import functional_call
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.selector import (_Backend, backend_name_to_enum,
|
||||||
|
get_global_forced_attn_backend)
|
||||||
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
|
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
|
||||||
SchedulerConfig)
|
SchedulerConfig)
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.loader import build_model
|
from vllm.model_executor.model_loader.loader import build_model
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
from vllm.multimodal.base import NestedTensors
|
from vllm.multimodal.base import NestedTensors
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import is_pin_memory_available
|
from vllm.utils import is_cpu, is_pin_memory_available
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
WeightsMapping = Mapping[str, Optional[str]]
|
WeightsMapping = Mapping[str, Optional[str]]
|
||||||
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
|
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
|
||||||
@ -487,3 +494,29 @@ class LLMWrapper(nn.Module):
|
|||||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
llm = super().__getattr__(self.model_name)
|
llm = super().__getattr__(self.model_name)
|
||||||
return llm(*args, **kwargs)
|
return llm(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_vit_attn_backend() -> _Backend:
|
||||||
|
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
|
||||||
|
if selected_backend is None:
|
||||||
|
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||||
|
if backend_by_env_var is not None:
|
||||||
|
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||||
|
if selected_backend is None:
|
||||||
|
# For Volta and Turing GPUs, use xformers instead.
|
||||||
|
device_available = current_platform.has_device_capability(80)
|
||||||
|
if device_available:
|
||||||
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
selected_backend = _Backend.FLASH_ATTN
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Current `vllm-flash-attn` has a bug inside vision module, "
|
||||||
|
"so we use xformers backend instead. You can run "
|
||||||
|
"`pip install flash-attn` to use flash-attention backend.")
|
||||||
|
selected_backend = _Backend.XFORMERS
|
||||||
|
elif is_cpu():
|
||||||
|
selected_backend = _Backend.TORCH_SDPA
|
||||||
|
else:
|
||||||
|
selected_backend = _Backend.XFORMERS
|
||||||
|
return selected_backend
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user