[Bugfix] Fix PaliGemma MMP (#6930)
This commit is contained in:
parent
6e063ea35b
commit
c66c7f86ac
@ -9,7 +9,6 @@ from vllm.attention import AttentionMetadata
|
|||||||
from vllm.config import CacheConfig, MultiModalConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
|
|||||||
def __init__(self, vision_hidden_size: int, projection_dim: int):
|
def __init__(self, vision_hidden_size: int, projection_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.linear = ColumnParallelLinear(vision_hidden_size,
|
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
|
||||||
projection_dim,
|
|
||||||
bias=True)
|
|
||||||
|
|
||||||
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states, _ = self.linear(image_features)
|
hidden_states = self.linear(image_features)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user