[Model][LoRA]LoRA support added for idefics3 (#10281)
Signed-off-by: B-201 <Joy25810@foxmail.com>
This commit is contained in:
parent b6dde33019
commit d909acf9fe
@@ -450,7 +450,7 @@ Text Generation
     - Idefics3
     - T + I
     - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
-    -
+    - ✅︎
     -
   * - :code:`InternVLChatModel`
     - InternVL2
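For context, the table change above marks LoRA as supported for Idefics3 in the docs. A minimal offline-inference sketch of what that enables is below; the adapter name, adapter path, rank setting, and prompt text are illustrative assumptions, not part of this commit.

from PIL import Image

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable LoRA when constructing the engine (settings assumed for illustration).
llm = LLM(
    model="HuggingFaceM4/Idefics3-8B-Llama3",
    enable_lora=True,
    max_lora_rank=16,  # must be >= the adapter's rank
)

# "my-idefics3-adapter" and its path are hypothetical placeholders.
lora_request = LoRARequest("my-idefics3-adapter", 1, "/path/to/adapter")

# A blank image keeps the sketch self-contained; use a real image in practice.
image = Image.new("RGB", (640, 480), color=0)

outputs = llm.generate(
    {
        # Prompt format follows Idefics3's chat template as used in vLLM's
        # examples; treat it as illustrative.
        "prompt": "<|begin_of_text|>User:<image>Describe the image."
                  "<end_of_utterance>\nAssistant:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
    lora_request=lora_request,
)
print(outputs[0].outputs[0].text)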
@@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.image import cached_get_image_processor
@@ -44,7 +45,7 @@ from vllm.utils import is_list_of
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -58,8 +59,6 @@ class Idefics3ImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
     """
-    rows: List[int]
-    cols: List[int]
     pixel_attention_mask: Optional[torch.BoolTensor]
 
 
@@ -356,8 +355,15 @@ def dummy_data_for_idefics3(
     image_seq_len = processor.image_seq_len
     max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
 
+    if seq_len - max_llm_image_tokens < 0:
+        raise RuntimeError(
+            f"Idefics3 cannot process {num_images} images in a prompt, "
+            "please increase max_model_len or reduce image limit by "
+            "--limit-mm-per-prompt.")
+
     seq_data = SequenceData.from_prompt_token_counts(
-        (hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))
+        (hf_config.image_token_id, max_llm_image_tokens),
+        (0, seq_len - max_llm_image_tokens))
 
     width = height = hf_config.vision_config.image_size
     image = Image.new("RGB", (width, height), color=0)
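A short worked check of the new accounting in this hunk (all numbers assumed for illustration): before the fix, the dummy sequence held max_llm_image_tokens plus a further seq_len padding tokens, overshooting the sequence budget; after the fix it totals exactly seq_len, and a too-small budget now fails fast with the RuntimeError.

# Assumed example values, not taken from the commit itself.
image_seq_len = 169          # tokens emitted per image patch
max_num_image_patches = 17   # patches at the largest assumed resolution
num_images = 2
seq_len = 8192

max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
# 17 * 169 * 2 = 5746 image tokens

# Otherwise the new RuntimeError fires, asking for a larger max_model_len
# or a lower --limit-mm-per-prompt.
assert seq_len - max_llm_image_tokens >= 0

# Old behavior: 5746 + 8192 dummy tokens (too many).
# New behavior: 5746 + (8192 - 5746) = 8192 tokens, exactly seq_len.
padding_tokens = seq_len - max_llm_image_tokens  # 2446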
@@ -463,8 +469,6 @@ class Idefics3Model(nn.Module):
             self, **kwargs: object) -> Optional[ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
-        rows = kwargs.pop("rows", None)
-        cols = kwargs.pop("cols", None)
         pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
 
         if pixel_values is None and image_embeds is None:
@@ -489,8 +493,6 @@ class Idefics3Model(nn.Module):
                 data=self._validate_pixel_values(
                     flatten_bn(pixel_values,
                                concat=True)),
-                rows=rows,
-                cols=cols,
                 pixel_attention_mask=flatten_bn(
                     pixel_attention_mask,
                     concat=True))
@@ -610,7 +612,33 @@ class Idefics3Model(nn.Module):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
-class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                       SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision_model
+        "fc1",
+        "fc2",
+        "out_proj",
+        # text_model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
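The packed_modules_mapping above exists because vLLM fuses q/k/v (and gate/up) into single linear layers, while LoRA adapters trained with PEFT keep them as separate modules. A minimal sketch of the grouping idea follows; the helper function and checkpoint key layout are hypothetical illustrations, not vLLM's actual LoRA-manager internals.

from typing import Dict, List

# Mirrors the mapping declared on Idefics3ForConditionalGeneration.
PACKED_MODULES_MAPPING: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def group_lora_weights(adapter_keys: List[str]) -> Dict[str, List[str]]:
    """Group per-projection adapter keys under the fused module they feed."""
    grouped: Dict[str, List[str]] = {}
    for packed, parts in PACKED_MODULES_MAPPING.items():
        for key in adapter_keys:
            if any(key.endswith(f"{part}.lora_A.weight") or
                   key.endswith(f"{part}.lora_B.weight") for part in parts):
                grouped.setdefault(packed, []).append(key)
    return grouped

# Example keys as a PEFT checkpoint might name them (assumed layout).
keys = [
    "model.text_model.layers.0.self_attn.q_proj.lora_A.weight",
    "model.text_model.layers.0.self_attn.k_proj.lora_A.weight",
    "model.text_model.layers.0.self_attn.v_proj.lora_A.weight",
]
print(group_lora_weights(keys))  # all three land under "qkv_proj"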
@@ -672,3 +700,12 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="model.text_model",
+            connector="model.connector",
+            tower_model="model.vision_model")
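The new get_mm_mapping lets LoRA code distinguish the language model, connector, and vision tower by module prefix. A rough sketch of how a consumer might use the returned prefixes is below; the filtering helper and the module-name list are assumptions for illustration, not the actual vLLM logic.

from typing import List

def modules_under(prefix: str, module_names: List[str]) -> List[str]:
    """Return the module names that live under a given prefix."""
    return [name for name in module_names if name.startswith(prefix)]

# Simulated module names following this model's layout.
names = [
    "model.text_model.layers.0.self_attn.qkv_proj",
    "model.connector.modality_projection.proj",
    "model.vision_model.encoder.layers.0.self_attn.out_proj",
]

# These prefixes mirror what Idefics3's get_mm_mapping() returns.
print(modules_under("model.text_model", names))    # language model modules
print(modules_under("model.vision_model", names))  # vision tower modules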