[Model][LoRA]LoRA support added for idefics3 (#10281)
Signed-off-by: B-201 <Joy25810@foxmail.com>
This commit is contained in:
parent
b6dde33019
commit
d909acf9fe
@ -450,7 +450,7 @@ Text Generation
|
||||
- Idefics3
|
||||
- T + I
|
||||
- :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
|
||||
-
|
||||
- ✅︎
|
||||
-
|
||||
* - :code:`InternVLChatModel`
|
||||
- InternVL2
|
||||
|
||||
@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
@ -44,7 +45,7 @@ from vllm.utils import is_list_of
|
||||
from .idefics2_vision_model import (
|
||||
Idefics2VisionTransformer as Idefics3VisionTransformer)
|
||||
# yapf: enable
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal
|
||||
from .llama import LlamaModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
@ -58,8 +59,6 @@ class Idefics3ImagePixelInputs(TypedDict):
|
||||
"""
|
||||
Shape: `(batch_size * num_images, num_channels, height, width)`
|
||||
"""
|
||||
rows: List[int]
|
||||
cols: List[int]
|
||||
pixel_attention_mask: Optional[torch.BoolTensor]
|
||||
|
||||
|
||||
@ -356,8 +355,15 @@ def dummy_data_for_idefics3(
|
||||
image_seq_len = processor.image_seq_len
|
||||
max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
|
||||
|
||||
if seq_len - max_llm_image_tokens < 0:
|
||||
raise RuntimeError(
|
||||
f"Idefics3 cannot process {num_images} images in a prompt, "
|
||||
"please increase max_model_len or reduce image limit by "
|
||||
"--limit-mm-per-prompt.")
|
||||
|
||||
seq_data = SequenceData.from_prompt_token_counts(
|
||||
(hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))
|
||||
(hf_config.image_token_id, max_llm_image_tokens),
|
||||
(0, seq_len - max_llm_image_tokens))
|
||||
|
||||
width = height = hf_config.vision_config.image_size
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
@ -463,8 +469,6 @@ class Idefics3Model(nn.Module):
|
||||
self, **kwargs: object) -> Optional[ImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
rows = kwargs.pop("rows", None)
|
||||
cols = kwargs.pop("cols", None)
|
||||
pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
@ -489,8 +493,6 @@ class Idefics3Model(nn.Module):
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values,
|
||||
concat=True)),
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
pixel_attention_mask=flatten_bn(
|
||||
pixel_attention_mask,
|
||||
concat=True))
|
||||
@ -610,7 +612,33 @@ class Idefics3Model(nn.Module):
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
|
||||
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
# vision_model
|
||||
"fc1",
|
||||
"fc2",
|
||||
"out_proj",
|
||||
# text_model
|
||||
"qkv_proj", # same name with vision encoder
|
||||
"o_proj",
|
||||
"gate_up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@ -672,3 +700,12 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="model.text_model",
|
||||
connector="model.connector",
|
||||
tower_model="model.vision_model")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user