[torch.compile] Adding torch compile to vision-language models (#9946)

Authored by Yongzao on 2024-11-03 03:56:05 +08:00, committed by GitHub
parent 1b73ab2a1f
commit ae5279a163
3 changed files with 21 additions and 8 deletions

vllm/model_executor/models/llava_next.py

@@ -606,7 +606,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         :class:`LlavaNextImageInputs`
         """
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
@@ -618,9 +617,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                     self.language_model.model.get_input_embeddings,
                     lambda _: self._process_image_input(image_input),
                 )
-                input_ids = None
             else:
-                inputs_embeds = None
+                inputs_embeds = self.language_model.model.get_input_embeddings(
+                    input_ids)
+
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            # for `torch.compile` integration
+            input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
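The hunk above shows the pattern this commit applies everywhere: compute `inputs_embeds` on every path (merged multimodal embeddings or a plain token-embedding lookup) and hand the language model `input_ids=None`, so the compiled backbone always sees one calling convention. Below is a minimal, self-contained sketch of that idea in plain PyTorch; `ToyLanguageModel` and `vlm_forward` are illustrative stand-ins, not vLLM code.

    import torch
    import torch.nn as nn

    class ToyLanguageModel(nn.Module):
        """Stand-in for the language-model backbone that gets compiled."""

        def __init__(self, vocab_size: int = 128, hidden: int = 32):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden)
            self.layers = nn.Linear(hidden, hidden)  # stands in for the decoder stack

        def forward(self, input_ids, inputs_embeds=None):
            # With the pattern from this commit, `inputs_embeds` is always
            # supplied and `input_ids` is always None, so the compiled graph
            # never has to special-case the embedding lookup.
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            return self.layers(inputs_embeds)

    lm = ToyLanguageModel()
    lm.forward = torch.compile(lm.forward)

    def vlm_forward(input_ids, image_embeds=None):
        if image_embeds is not None:
            inputs_embeds = image_embeds                # multimodal request
        else:
            inputs_embeds = lm.embed_tokens(input_ids)  # text-only request
        # always pass the input via `inputs_embeds` so the compiled
        # language model sees a consistent computation graph
        input_ids = None
        return lm(input_ids, inputs_embeds=inputs_embeds)

    ids = torch.randint(0, 128, (1, 8))
    out_text = vlm_forward(ids)
    out_mm = vlm_forward(ids, image_embeds=torch.randn(1, 8, 32))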

vllm/model_executor/models/minicpmv.py

@@ -564,8 +564,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
 
+        # always pass the input via `inputs_embeds`
+        # to make sure the computation graph is consistent
+        # for `torch.compile` integration
+        input_ids = None
+
         output = self.llm(
-            input_ids=None,
+            input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
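For the text-only path this is a behavior-preserving change: doing the token-embedding lookup in the caller and passing the result as `inputs_embeds` yields the same activations the language model would compute from `input_ids` itself. A tiny illustrative check with a toy `lm` function (not vLLM's API):

    import torch
    import torch.nn as nn

    embed_tokens = nn.Embedding(100, 16)
    decoder = nn.Linear(16, 16)

    def lm(input_ids=None, inputs_embeds=None):
        # toy language model: embeds the ids itself unless embeddings are given
        if inputs_embeds is None:
            inputs_embeds = embed_tokens(input_ids)
        return decoder(inputs_embeds)

    ids = torch.randint(0, 100, (1, 4))
    assert torch.allclose(lm(input_ids=ids),
                          lm(inputs_embeds=embed_tokens(ids)))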

vllm/model_executor/models/molmo.py

@@ -15,6 +15,7 @@ from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.selector import _Backend
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -713,6 +714,7 @@ class MolmoVisionBackbone(nn.Module):
         return image_features
 
 
+@support_torch_compile
 class MolmoModel(nn.Module):
 
     def __init__(
@@ -1141,7 +1143,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         **kwargs: object,
     ) -> SamplerOutput:
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
@@ -1156,10 +1157,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                     image_input["image_input_idx"],
                     image_input["seq_len"],
                 )
-                input_ids = None
             else:
-                inputs_embeds = None
+                inputs_embeds = self.model.embed_tokens(input_ids)
+
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            # for `torch.compile` integration
+            input_ids = None
 
         hidden_states = self.model(
             input_ids=input_ids,
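The molmo.py change also decorates `MolmoModel` with `support_torch_compile` from `vllm.compilation.decorators`, which wires the inner model into vLLM's torch.compile integration. The real decorator is tied into vLLM's compilation configuration; the sketch below only illustrates the general idea of a class decorator that compiles a module's forward. `toy_support_torch_compile` and `ToyModel` are hypothetical names, not vLLM's implementation.

    import torch
    import torch.nn as nn

    def toy_support_torch_compile(cls):
        """Class decorator that compiles `forward` after construction.

        A much-simplified stand-in for support_torch_compile in
        vllm.compilation.decorators, which does considerably more work
        around vLLM's own compilation configuration.
        """
        orig_init = cls.__init__

        def wrapped_init(self, *args, **kwargs):
            orig_init(self, *args, **kwargs)
            # compile the bound forward once the module exists
            self.forward = torch.compile(self.forward, dynamic=True)

        cls.__init__ = wrapped_init
        return cls

    @toy_support_torch_compile
    class ToyModel(nn.Module):

        def __init__(self, hidden: int = 32):
            super().__init__()
            self.proj = nn.Linear(hidden, hidden)

        def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
            return self.proj(inputs_embeds)

    model = ToyModel()
    out = model(torch.randn(2, 8, 32))  # runs through the compiled forward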