[torch.compile] Adding torch compile annotations to some models (#9876)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>

parent 93a76dd21d
commit 2b5bf20988
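The hunks below all make the same change: they import support_torch_compile from vllm.compilation.decorators and place it on each model's inner *Model class. As a rough illustration of what such a class decorator can do, here is a minimal self-contained sketch. It is not vLLM's actual implementation (which hooks into vLLM's compilation machinery); the names support_torch_compile_sketch and ToyModel are invented for the example.

# Minimal stand-in for a support_torch_compile-style class decorator:
# wrap each instance's forward in torch.compile at construction time,
# so the graph is captured lazily on the first call.
import torch
from torch import nn


def support_torch_compile_sketch(cls):
    original_init = cls.__init__

    def wrapped_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Shadow the bound forward with a compiled wrapper.
        self.forward = torch.compile(self.forward)

    cls.__init__ = wrapped_init
    return cls


@support_torch_compile_sketch
class ToyModel(nn.Module):  # stands in for FalconModel, Qwen2Model, etc.
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


out = ToyModel()(torch.randn(2, 16))  # first call triggers compilation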
@@ -281,7 +281,7 @@ Text Generation
     - ✅︎
   * - :code:`Qwen2ForCausalLM`
     - Qwen2
-    - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
+    - :code:`Qwen/Qwen2-7B-Instruct`, :code:`Qwen/Qwen2-7B`, etc.
     - ✅︎
     - ✅︎
   * - :code:`Qwen2MoeForCausalLM`
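The documentation hunk above swaps the pre-release Qwen2-beta checkpoint names for the released ones. For reference, a minimal example of loading the renamed checkpoint with vLLM's offline LLM API; the prompt and sampling values are arbitrary, and this assumes the model fits on the local GPU.

# Offline generation with the checkpoint named in the updated table.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-7B-Instruct")
params = SamplingParams(temperature=0.7, max_tokens=32)
outputs = llm.generate(["Briefly explain torch.compile."], params)
print(outputs[0].outputs[0].text)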
@@ -166,7 +166,7 @@ TEXT_GENERATION_MODELS = {
     "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
     "adept/persimmon-8b-chat": PPTestSettings.fast(),
     "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
-    "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
+    "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
     "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
     "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
     "bigcode/starcoder2-3b": PPTestSettings.fast(),
@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -329,6 +330,7 @@ class FalconDecoderLayer(nn.Module):
         return output


+@support_torch_compile
 class FalconModel(nn.Module):

     def __init__(
@@ -42,6 +42,7 @@ from torch import nn
 from transformers import PhiConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -193,6 +194,7 @@ class PhiLayer(nn.Module):
         return hidden_states


+@support_torch_compile
 class PhiModel(nn.Module):

     def __init__(self,
@@ -20,6 +20,7 @@ from torchvision.transforms import InterpolationMode
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
@@ -549,6 +550,7 @@ class QWenBlock(nn.Module):
         return hidden_states, residual


+@support_torch_compile
 class QWenModel(nn.Module):

     def __init__(
@@ -29,6 +29,7 @@ from torch import nn
 from transformers import Qwen2Config

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -237,6 +238,7 @@ class Qwen2DecoderLayer(nn.Module):
         return hidden_states, residual


+@support_torch_compile
 class Qwen2Model(nn.Module):

     def __init__(
@@ -30,6 +30,7 @@ from torch import nn
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
@@ -312,6 +313,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         return hidden_states, residual


+@support_torch_compile
 class Qwen2MoeModel(nn.Module):

     def __init__(
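Every file follows the same two-line pattern: add the import, then annotate the inner *Model class that owns the decoder layers. The payoff of such annotations is that the graph is captured once and then reused across calls. The following self-contained experiment (independent of vLLM; the counting backend is invented for the example) makes that visible by counting how often Dynamo compiles a module.

# Count graph compilations with a custom torch.compile backend.
import torch
from torch import nn

compile_count = 0


def counting_backend(gm, example_inputs):
    # Invoked once per captured graph; execute it eagerly afterwards.
    global compile_count
    compile_count += 1
    return gm.forward


layer = torch.compile(nn.Linear(8, 8), backend=counting_backend)
layer(torch.randn(2, 8))  # first call: one compilation
layer(torch.randn(2, 8))  # same shape: served from the compiled graph
print(compile_count)      # expected: 1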