[Model] Add BNB quantization support for Idefics3 (#10310)
Signed-off-by: B-201 <Joy25810@foxmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
52b48c1ead
commit
294bf467ba
@ -22,6 +22,7 @@ import torch.utils.checkpoint
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
# Temporary solution for transformers below 4.46.0.
|
||||
from transformers import PretrainedConfig as Idefics3Config
|
||||
from transformers import ProcessorMixin as Idefics3ImageProcessor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
@ -31,6 +32,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
@ -374,12 +376,23 @@ def dummy_data_for_idefics3(
|
||||
|
||||
class Idefics3SimpleMLP(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config: Idefics3Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
input_size = config.vision_config.hidden_size * (config.scale_factor**
|
||||
2)
|
||||
output_size = config.text_config.hidden_size
|
||||
self.proj = ReplicatedLinear(input_size, output_size, bias=False)
|
||||
self.proj = ReplicatedLinear(
|
||||
input_size,
|
||||
output_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "proj"),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out, _ = self.proj(x)
|
||||
@ -388,10 +401,19 @@ class Idefics3SimpleMLP(nn.Module):
|
||||
|
||||
class Idefics3Connector(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config: Idefics3Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.scale_factor = config.scale_factor
|
||||
self.modality_projection = Idefics3SimpleMLP(config)
|
||||
self.modality_projection = Idefics3SimpleMLP(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "modality_projection"),
|
||||
)
|
||||
|
||||
def pixel_shuffle(self,
|
||||
x: torch.Tensor,
|
||||
@ -431,9 +453,15 @@ class Idefics3Model(nn.Module):
|
||||
self.config = config
|
||||
self.padding_idx = self.config.text_config.pad_token_id
|
||||
self.vocab_size = self.config.text_config.vocab_size
|
||||
self.vision_model = Idefics3VisionTransformer(config.vision_config,
|
||||
quant_config)
|
||||
self.connector = Idefics3Connector(config)
|
||||
self.vision_model = Idefics3VisionTransformer(
|
||||
config.vision_config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_model"))
|
||||
self.connector = Idefics3Connector(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "connector"),
|
||||
)
|
||||
self.text_model = LlamaModel(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
prefix=maybe_prefix(prefix, "text_model"),
|
||||
@ -637,6 +665,32 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"gate_up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
|
||||
# BitsAndBytes-specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
".down_proj.",
|
||||
".up_proj.",
|
||||
".q_proj.",
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
# vision_model
|
||||
".fc1.",
|
||||
".fc2.",
|
||||
".out_proj.",
|
||||
# connector
|
||||
".proj.",
|
||||
]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name: (weight_name, index)
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user