[Bugfix] Fix PaliGemma MMP (#6930)

commit c66c7f86ac
parent 6e063ea35b
Author: Roger Wang
Date:   2024-07-30 02:20:57 -07:00
Committer: GitHub

vllm/model_executor/models/paligemma.py

@@ -9,7 +9,6 @@ from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):

     def __init__(self, vision_hidden_size: int, projection_dim: int):
         super().__init__()
-        self.linear = ColumnParallelLinear(vision_hidden_size,
-                                           projection_dim,
-                                           bias=True)
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states, _ = self.linear(image_features)
+        hidden_states = self.linear(image_features)
         return hidden_states
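
For context, the two layers differ in both sharding and return type: vLLM's ColumnParallelLinear splits its weight across tensor-parallel ranks and its forward() returns an (output, output_bias) tuple, while torch.nn.Linear keeps the small projector weight replicated on every rank and returns the tensor directly, which is why the tuple unpacking in forward() is dropped alongside the layer swap. Below is a minimal standalone sketch of the patched projector; the class body mirrors the diff, and the dimensions in the usage lines are illustrative assumptions (SigLIP-sized features into a Gemma-sized embedding space), not values taken from this commit.

import torch
import torch.nn as nn


class PaliGemmaMultiModalProjector(nn.Module):
    """Projects vision-tower features into the language model's embedding space."""

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()
        # Plain unsharded linear layer: the weight stays replicated on every
        # tensor-parallel rank instead of being column-sharded.
        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # nn.Linear returns the tensor directly; no (output, bias) tuple to unpack.
        return self.linear(image_features)


# Illustrative usage with assumed dimensions (not from this commit):
proj = PaliGemmaMultiModalProjector(vision_hidden_size=1152, projection_dim=2048)
image_features = torch.randn(1, 256, 1152)  # (batch, image_tokens, vision_hidden_size)
projected = proj(image_features)            # shape: (1, 256, 2048)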