[torch.compile] Adding torch compile annotations to some models (#9641)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Yongzao 2024-10-25 00:31:42 +08:00 committed by GitHub
parent de662d32b5
commit d27cfbf791
7 changed files with 21 additions and 11 deletions

View File

@@ -171,7 +171,8 @@ TEXT_GENERATION_MODELS = {
"stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
"bigcode/starcoder2-3b": PPTestSettings.fast(),
"upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
# FIXME: Cannot load tokenizer in latest transformers version
# FIXME: Cannot load tokenizer in latest transformers version.
# Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
# "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
# [Encoder-only]
# TODO: Implement PP
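
The FIXME above describes a tokenizer workaround rather than showing it. Purely as a hedged illustration (not something this test currently does), pairing the XVERSE weights with the Llama-2 chat tokenizer could look roughly like this in vLLM's offline API:

# Hypothetical, untested sketch of the workaround the FIXME suggests:
# load xverse/XVERSE-7B-Chat but take the tokenizer from Llama-2.
from vllm import LLM

llm = LLM(
    model="xverse/XVERSE-7B-Chat",
    tokenizer="meta-llama/Llama-2-7b-chat-hf",
    trust_remote_code=True,
)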

View File

@@ -24,6 +24,7 @@ from torch import nn
from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -279,6 +280,7 @@ class OPTDecoder(nn.Module):
return hidden_states
@support_torch_compile
class OPTModel(nn.Module):
def __init__(
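
The hunks above show the entire pattern this commit repeats for each of the model files that follow: import support_torch_compile from vllm.compilation.decorators and place it on the top-level nn.Module subclass. Conceptually, a class-level compile annotation boils down to lazily wrapping the module's forward with torch.compile; the sketch below illustrates only that general idea in plain PyTorch and is not vLLM's actual decorator implementation.

import torch
from torch import nn


def compile_annotation(cls):
    # Hypothetical stand-in for a decorator like support_torch_compile:
    # compile the class's forward lazily, on the first call.
    original_forward = cls.forward

    def forward(self, *args, **kwargs):
        if not hasattr(self, "_compiled_forward"):
            self._compiled_forward = torch.compile(original_forward)
        return self._compiled_forward(self, *args, **kwargs)

    cls.forward = forward
    return cls


@compile_annotation
class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)

Calling ToyModel()(torch.randn(2, 16)) triggers compilation on the first forward pass. vLLM's real decorator is integrated with its own compilation configuration and runtime, which this sketch does not attempt to model.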

View File

@@ -11,6 +11,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
@@ -184,7 +185,6 @@ class OrionDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
@@ -203,9 +203,10 @@ class OrionDecoderLayer(nn.Module):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, None
return hidden_states
@support_torch_compile
class OrionModel(nn.Module):
def __init__(
@@ -233,8 +234,9 @@ class OrionModel(nn.Module):
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
make_empty_intermediate_tensors_factory([
"hidden_states",
], config.hidden_size))
def forward(
self,
@@ -246,24 +248,20 @@ class OrionModel(nn.Module):
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
return hidden_states
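
For readability, here is the post-change body of OrionModel.forward consolidated from the interleaved hunks above (reconstructed as a reading aid, not copied verbatim from the file). The decoder layer now folds the residual back into hidden_states before returning, so only "hidden_states" has to cross pipeline-parallel ranks:

if get_pp_group().is_first_rank:
    hidden_states = self.embed_tokens(input_ids)
else:
    assert intermediate_tensors is not None
    hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
    layer = self.layers[i]
    hidden_states = layer(
        positions,
        hidden_states,
        kv_caches[i - self.start_layer],
        attn_metadata,
    )
if not get_pp_group().is_last_rank:
    return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states)
return hidden_states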

View File

@@ -27,6 +27,7 @@ from torch import nn
from transformers import PersimmonConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -209,6 +210,7 @@ class PersimmonDecoderLayer(nn.Module):
return outputs
@support_torch_compile
class PersimmonModel(nn.Module):
def __init__(self,

View File

@@ -29,6 +29,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
@@ -263,6 +264,7 @@ class SolarDecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile
class SolarModel(nn.Module):
def __init__(

View File

@@ -25,6 +25,7 @@ from torch import nn
from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -193,6 +194,7 @@ class Starcoder2DecoderLayer(nn.Module):
return hidden_states
@support_torch_compile
class Starcoder2Model(nn.Module):
def __init__(self,

View File

@@ -27,6 +27,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
@@ -220,6 +221,7 @@ class XverseDecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile
class XverseModel(nn.Module):
def __init__(
@@ -266,6 +268,7 @@ class XverseModel(nn.Module):
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(