[torch.compile] Adding torch compile annotations to some models (#9641)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
parent de662d32b5
commit d27cfbf791
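Context for the hunks below: the commit adds the @support_torch_compile class decorator (imported from vllm.compilation.decorators) to the top-level *Model classes of several models so that their forward passes can be captured and compiled when vLLM's torch.compile integration is enabled. As a rough illustration of the underlying idea only (not vLLM's actual decorator machinery), the following toy sketch applies plain torch.compile to a hypothetical module; ToyBlock is a made-up stand-in and is not part of this commit.

import torch
from torch import nn


class ToyBlock(nn.Module):
    # Hypothetical stand-in for a decoder-style module; not vLLM code.

    def __init__(self, hidden_size: int = 64) -> None:
        super().__init__()
        self.up = nn.Linear(hidden_size, 4 * hidden_size)
        self.down = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Residual MLP, loosely mirroring the decoder layers touched below.
        return hidden_states + self.down(torch.relu(self.up(hidden_states)))


# Plain torch.compile stands in here for what @support_torch_compile arranges
# inside vLLM: the first call traces and compiles forward(), and later calls
# with the same shapes reuse the compiled graph.
block = torch.compile(ToyBlock())
out = block(torch.randn(2, 64))
print(out.shape)  # torch.Size([2, 64])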
@@ -171,7 +171,8 @@ TEXT_GENERATION_MODELS = {
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
    # FIXME: Cannot load tokenizer in latest transformers version
    # FIXME: Cannot load tokenizer in latest transformers version.
    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
    # [Encoder-only]
    # TODO: Implement PP
@@ -24,6 +24,7 @@ from torch import nn
from transformers import OPTConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -279,6 +280,7 @@ class OPTDecoder(nn.Module):
        return hidden_states


@support_torch_compile
class OPTModel(nn.Module):

    def __init__(
@@ -11,6 +11,7 @@ from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
@@ -184,7 +185,6 @@ class OrionDecoderLayer(nn.Module):
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
@@ -203,9 +203,10 @@ class OrionDecoderLayer(nn.Module):
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, None
        return hidden_states


@support_torch_compile
class OrionModel(nn.Module):

    def __init__(
@@ -233,8 +234,9 @@ class OrionModel(nn.Module):
            prefix=f"{prefix}.layers")
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
            make_empty_intermediate_tensors_factory([
                "hidden_states",
            ], config.hidden_size))

    def forward(
        self,
@@ -246,24 +248,20 @@ class OrionModel(nn.Module):
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states = self.norm(hidden_states)
        return hidden_states
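The Orion hunks above simplify the forward pass so that only hidden_states is threaded between decoder layers and between pipeline stages: OrionDecoderLayer already folds the residual back into hidden_states before returning (its second return value was always None), and the make_empty_intermediate_tensors_factory call now declares only "hidden_states". A minimal toy sketch of the resulting pattern, using made-up ToyLayer modules rather than vLLM code:

import torch
from torch import nn


class ToyLayer(nn.Module):
    # Toy stand-in for OrionDecoderLayer: the residual is added back in
    # before returning, so nothing else needs to be carried between layers.

    def __init__(self, hidden_size: int = 32) -> None:
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        return residual + self.mlp(hidden_states)


layers = nn.ModuleList([ToyLayer() for _ in range(4)])
hidden_states = torch.randn(2, 32)
# Only hidden_states flows through the layer loop (and, in vLLM, through
# IntermediateTensors between pipeline stages); no residual tensor rides along.
for layer in layers:
    hidden_states = layer(hidden_states)
print(hidden_states.shape)  # torch.Size([2, 32])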
@@ -27,6 +27,7 @@ from torch import nn
from transformers import PersimmonConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -209,6 +210,7 @@ class PersimmonDecoderLayer(nn.Module):
        return outputs


@support_torch_compile
class PersimmonModel(nn.Module):

    def __init__(self,
@@ -29,6 +29,7 @@ from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
@@ -263,6 +264,7 @@ class SolarDecoderLayer(nn.Module):
        return hidden_states, residual


@support_torch_compile
class SolarModel(nn.Module):

    def __init__(
@@ -25,6 +25,7 @@ from torch import nn
from transformers import Starcoder2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -193,6 +194,7 @@ class Starcoder2DecoderLayer(nn.Module):
        return hidden_states


@support_torch_compile
class Starcoder2Model(nn.Module):

    def __init__(self,
@@ -27,6 +27,7 @@ from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
@@ -220,6 +221,7 @@ class XverseDecoderLayer(nn.Module):
        return hidden_states, residual


@support_torch_compile
class XverseModel(nn.Module):

    def __init__(
@@ -266,6 +268,7 @@ class XverseModel(nn.Module):
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(