[Model][LoRA] Add LoRA support for Idefics3 (#10281)

Signed-off-by: B-201 <Joy25810@foxmail.com>
Author: B-201, 2024-11-13 17:25:59 +08:00 (committed by GitHub)
parent b6dde33019
commit d909acf9fe
2 changed files with 47 additions and 10 deletions

docs/source/models/supported_models.rst

@@ -450,7 +450,7 @@ Text Generation
     - Idefics3
     - T + I
     - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
-    -
+    - ✅︎
     -
   * - :code:`InternVLChatModel`
     - InternVL2
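
With this change, the supported-models table marks Idefics3 as LoRA-capable. A minimal offline-inference sketch with a LoRA adapter follows; the adapter name and path are placeholders, not part of this commit, and assume a LoRA checkpoint trained against the modules listed in supported_lora_modules below:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder path to an Idefics3-compatible LoRA adapter checkpoint.
ADAPTER_PATH = "/path/to/idefics3-lora-adapter"

llm = LLM(
    model="HuggingFaceM4/Idefics3-8B-Llama3",
    enable_lora=True,   # activates the LoRA plumbing this commit adds
    max_lora_rank=32,   # must be >= the adapter's rank
)
# Text-only prompt for brevity; image inputs work the same way.
outputs = llm.generate(
    "Describe the image in one sentence.",
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("idefics3-adapter", 1, ADAPTER_PATH),
)
print(outputs[0].outputs[0].text)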

vllm/model_executor/models/idefics3.py

@@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.image import cached_get_image_processor
@@ -44,7 +45,7 @@ from vllm.utils import is_list_of
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -58,8 +59,6 @@ class Idefics3ImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
     """
-    rows: List[int]
-    cols: List[int]
     pixel_attention_mask: Optional[torch.BoolTensor]
@@ -356,8 +355,15 @@ def dummy_data_for_idefics3(
     image_seq_len = processor.image_seq_len
     max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
+    if seq_len - max_llm_image_tokens < 0:
+        raise RuntimeError(
+            f"Idefics3 cannot process {num_images} images in a prompt, "
+            "please increase max_model_len or reduce the image limit via "
+            "--limit-mm-per-prompt.")

     seq_data = SequenceData.from_prompt_token_counts(
-        (hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))
+        (hf_config.image_token_id, max_llm_image_tokens),
+        (0, seq_len - max_llm_image_tokens))

     width = height = hf_config.vision_config.image_size
     image = Image.new("RGB", (width, height), color=0)
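
The two changes above keep the dummy prompt within the context window: the text padding is now seq_len - max_llm_image_tokens rather than a full seq_len on top of the image tokens, and an over-budget image count fails fast instead of silently overflowing. A back-of-the-envelope check with illustrative numbers (real values come from the HF processor and vision config):

# Illustrative numbers only; actual values depend on the model config.
num_images = 2
max_num_image_patches = 17  # e.g. a 4x4 tile grid plus the global image
image_seq_len = 169         # LLM tokens emitted per image patch

max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
assert max_llm_image_tokens == 5746

seq_len = 8192
assert seq_len - max_llm_image_tokens >= 0  # else the new RuntimeError fires

# Before: 5746 image tokens + 8192 pad tokens -> prompt exceeds seq_len.
# After:  5746 image tokens + (8192 - 5746) pad tokens == seq_len exactly.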
@@ -463,8 +469,6 @@ class Idefics3Model(nn.Module):
             self, **kwargs: object) -> Optional[ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
-        rows = kwargs.pop("rows", None)
-        cols = kwargs.pop("cols", None)
         pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)

         if pixel_values is None and image_embeds is None:
@@ -489,8 +493,6 @@ class Idefics3Model(nn.Module):
                 data=self._validate_pixel_values(
                     flatten_bn(pixel_values,
                                concat=True)),
-                rows=rows,
-                cols=cols,
                 pixel_attention_mask=flatten_bn(
                     pixel_attention_mask,
                     concat=True))
@@ -610,7 +612,33 @@ class Idefics3Model(nn.Module):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
-class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                       SupportsLoRA):
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA-specific attributes
+    supported_lora_modules = [
+        # vision_model
+        "fc1",
+        "fc2",
+        "out_proj",
+        # text_model
+        "qkv_proj",  # same name as in the vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
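
packed_modules_mapping declares that adapters trained against HF's separate q_proj/k_proj/v_proj and gate_proj/up_proj layers should be applied to vLLM's fused qkv_proj and gate_up_proj modules. A rough sketch of the name translation this mapping implies; the helper is illustrative, not vLLM's actual LoRA loader:

# Illustrative only: how a packed mapping resolves adapter target names.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def fused_target(hf_module: str) -> str:
    """Return the fused vLLM module name for an HF adapter target."""
    for fused, parts in packed_modules_mapping.items():
        if hf_module in parts:
            return fused
    return hf_module  # o_proj, down_proj, fc1/fc2 keep their names

assert fused_target("k_proj") == "qkv_proj"
assert fused_target("up_proj") == "gate_up_proj"
assert fused_target("o_proj") == "o_proj"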
@@ -672,3 +700,12 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix of each component in this multimodal model.
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="model.text_model",
+            connector="model.connector",
+            tower_model="model.vision_model")
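
get_mm_mapping gives the LoRA machinery stable prefixes for the model's three components, so adapter weights can be routed to the language model, connector, or vision tower. A hypothetical sketch of how such prefixes could be consumed (classify_weight is illustrative and not part of this commit):

# Hypothetical helper: route a checkpoint weight name to its component.
PREFIXES = {
    "model.text_model": "language_model",
    "model.connector": "connector",
    "model.vision_model": "tower_model",
}

def classify_weight(name: str) -> str:
    for prefix, component in PREFIXES.items():
        if name.startswith(prefix):
            return component
    return "other"

assert classify_weight(
    "model.text_model.layers.0.self_attn.qkv_proj.lora_A") == "language_model"
assert classify_weight(
    "model.vision_model.encoder.layers.0.mlp.fc1.lora_B") == "tower_model"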