[Bugfix] Fix PaliGemma MMP (#6930)
This commit is contained in:
parent
6e063ea35b
commit
c66c7f86ac
@ -9,7 +9,6 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
|
||||
def __init__(self, vision_hidden_size: int, projection_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.linear = ColumnParallelLinear(vision_hidden_size,
|
||||
projection_dim,
|
||||
bias=True)
|
||||
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
|
||||
|
||||
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.linear(image_features)
|
||||
hidden_states = self.linear(image_features)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user