[torch.compile] Adding torch compile to vision-language models (#9946)
This commit is contained in:
parent 1b73ab2a1f
commit ae5279a163
vllm/model_executor/models/llava_next.py

@@ -606,7 +606,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
             :class:`LlavaNextImageInputs`
         """
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
@@ -618,9 +617,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                     self.language_model.model.get_input_embeddings,
                     lambda _: self._process_image_input(image_input),
                 )
-                input_ids = None
             else:
-                inputs_embeds = None
+                inputs_embeds = self.language_model.model.get_input_embeddings(
+                    input_ids)
+
+        # always pass the input via `inputs_embeds`
+        # to make sure the computation graph is consistent
+        # for `torch.compile` integration
+        input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
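This hunk is the heart of the commit: both branches now resolve `input_ids` into `inputs_embeds` before the language model runs, and `input_ids` is then nulled unconditionally. A minimal sketch of the idea in plain PyTorch (toy module and illustrative names, not vLLM's actual classes):

import torch
import torch.nn as nn


class ToyLanguageModel(nn.Module):
    """Stand-in for the compiled language model backbone."""

    def __init__(self, vocab_size: int = 128, hidden: int = 32):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.layer = nn.Linear(hidden, hidden)

    def forward(self, input_ids, inputs_embeds=None):
        # Branching on which argument was supplied would change the
        # traced graph; the commit moves this decision out of the
        # compiled region so every call looks the same.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        return self.layer(inputs_embeds)


model = ToyLanguageModel()
compiled = torch.compile(model)

input_ids = torch.randint(0, 128, (1, 16))
# Embed outside the compiled module, then always pass embeddings with
# input_ids=None, mirroring the lines this commit adds.
inputs_embeds = model.embed_tokens(input_ids)
out = compiled(None, inputs_embeds=inputs_embeds)

With a single calling convention, Dynamo does not have to guard and re-trace when a request switches between text-only and multimodal inputs.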
vllm/model_executor/models/minicpmv.py

@@ -564,8 +564,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):

         vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)

+        # always pass the input via `inputs_embeds`
+        # to make sure the computation graph is consistent
+        # for `torch.compile` integration
+        input_ids = None
+
         output = self.llm(
-            input_ids=None,
+            input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
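This hunk only changes the calling convention; `self.get_embedding` already returns the merged text-plus-vision embeddings. For readers unfamiliar with that merge step, a hedged generic sketch of what such a helper does (illustrative names, not MiniCPM-V's actual implementation):

import torch


def merge_vision_embeddings(token_embeds: torch.Tensor,
                            vision_embeds: torch.Tensor,
                            is_image_token: torch.Tensor) -> torch.Tensor:
    # Overwrite the rows at image-placeholder positions with the
    # projected vision features; all other token rows pass through.
    merged = token_embeds.clone()
    merged[is_image_token] = vision_embeds.to(token_embeds.dtype)
    return merged


# Usage: 10 positions, of which positions 2-5 are image placeholders.
token_embeds = torch.randn(10, 32)
vision_embeds = torch.randn(4, 32)
mask = torch.zeros(10, dtype=torch.bool)
mask[2:6] = True
merged = merge_vision_embeddings(token_embeds, vision_embeds, mask)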
vllm/model_executor/models/molmo.py

@@ -15,6 +15,7 @@ from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.selector import _Backend
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -713,6 +714,7 @@ class MolmoVisionBackbone(nn.Module):
         return image_features


+@support_torch_compile
 class MolmoModel(nn.Module):

     def __init__(
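Note that the decorator lands on `MolmoModel`, the text backbone, rather than on the top-level `MolmoForCausalLM`: the vision backbone and the embedding merge are shape-dynamic and stay in eager mode. A toy analogue with plain `torch.compile` (vLLM's `support_torch_compile` adds compilation-config handling on top of this idea):

import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Shape-stable core: the part worth compiling."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(inputs_embeds))


class ToyVLM(nn.Module):
    """Wrapper stays eager: it hosts the dynamic multimodal glue."""

    def __init__(self):
        super().__init__()
        # Compile only the sub-module with a stable interface.
        self.backbone = torch.compile(Backbone())

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return self.backbone(inputs_embeds)


print(ToyVLM()(torch.randn(2, 16)).shape)  # torch.Size([2, 16])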
@@ -1141,7 +1143,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         **kwargs: object,
     ) -> SamplerOutput:
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
@@ -1156,10 +1157,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                     image_input["image_input_idx"],
                     image_input["seq_len"],
                 )

-                input_ids = None
             else:
-                inputs_embeds = None
+                inputs_embeds = self.model.embed_tokens(input_ids)
+        # always pass the input via `inputs_embeds`
+        # to make sure the computation graph is consistent
+        # for `torch.compile` integration
+        input_ids = None

         hidden_states = self.model(
             input_ids=input_ids,
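Taken together, the three files converge on the same control flow around the compiled backbone. A runnable distillation (toy modules and illustrative names, not a verbatim excerpt from any of the models):

import torch
import torch.nn as nn

embed_tokens = nn.Embedding(128, 32)
backbone = torch.compile(nn.Linear(32, 32))


@torch.no_grad()
def forward(input_ids, image_features=None):
    if image_features is not None:
        # Multimodal prompt: embed the text, then splice in the
        # (already projected) vision features.
        inputs_embeds = embed_tokens(input_ids)
        inputs_embeds[:, :image_features.shape[1]] = image_features
    else:
        # Text-only prompt.
        inputs_embeds = embed_tokens(input_ids)

    # always pass the input via `inputs_embeds`
    # to make sure the computation graph is consistent
    # for `torch.compile` integration
    input_ids = None

    return backbone(inputs_embeds)


tokens = torch.randint(0, 128, (1, 8))
out_text = forward(tokens)                                      # text-only
out_mm = forward(tokens, image_features=torch.randn(1, 4, 32))  # multimodal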