[torch.compile] Adding torch compile to vision-language models (#9946)
This commit is contained in:
parent
1b73ab2a1f
commit
ae5279a163
@ -606,7 +606,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
:class:`LlavaNextImageInputs`
|
||||
"""
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
else:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
@ -618,9 +617,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.language_model.model.get_input_embeddings,
|
||||
lambda _: self._process_image_input(image_input),
|
||||
)
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
|
||||
@ -564,8 +564,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
|
||||
output = self.llm(
|
||||
input_ids=None,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
|
||||
@ -15,6 +15,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -713,6 +714,7 @@ class MolmoVisionBackbone(nn.Module):
|
||||
return image_features
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class MolmoModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -1141,7 +1143,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
else:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
@ -1156,10 +1157,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_input["image_input_idx"],
|
||||
image_input["seq_len"],
|
||||
)
|
||||
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user