[torch.compile] Adding torch compile annotations to some models (#9639)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
commit 8a02cd045a (parent 4fdc581f9e)
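In short, the hunks below make two kinds of changes: they update the Jais model references from the core42 Hugging Face organization to inceptionai (in the supported-models table, the pipeline-parallel test list, and the source header of the Jais implementation), and they apply vLLM's @support_torch_compile decorator to the inner model classes of Jais, MiniCPM, MPT, Nemotron, and OLMo (JAISModel, MiniCPMModel, MPTModel, NemotronModel, OlmoModel), importing the decorator from vllm.compilation.decorators in each touched file.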
@@ -144,7 +144,7 @@ Text Generation
     - ✅︎
   * - :code:`JAISLMHeadModel`
     - Jais
-    - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
+    - :code:`inceptionai/jais-13b`, :code:`inceptionai/jais-13b-chat`, :code:`inceptionai/jais-30b-v3`, :code:`inceptionai/jais-30b-chat-v3`, etc.
     -
     - ✅︎
   * - :code:`JambaForCausalLM`
@@ -145,7 +145,7 @@ TEXT_GENERATION_MODELS = {
     # Uses Llama
     # "internlm/internlm-chat-7b": PPTestSettings.fast(),
     "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
-    "core42/jais-13b-chat": PPTestSettings.fast(),
+    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
     # TODO: Implement PP
     # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Adapted from
-# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
+# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
 # Copyright 2023 The vLLM team.
 # Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
 # reserved.
@@ -26,6 +26,7 @@ import torch
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -212,6 +213,7 @@ class JAISBlock(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class JAISModel(nn.Module):
 
     def __init__(
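The decorator line above is the whole annotation: the commit imports support_torch_compile from vllm.compilation.decorators and applies it to the inner *Model class that owns the decoder-layer stack, not to individual layers or to the outer *ForCausalLM wrapper; the same one-line change repeats for MiniCPM, MPT, Nemotron, and OLMo below. What the decorator does internally is vLLM-specific and not shown in this diff, so the following is only a simplified, hypothetical stand-in for the general idea of a class-level torch.compile annotation (toy_support_torch_compile and ToyModel are made-up names, not part of vLLM):

import torch
from torch import nn


def toy_support_torch_compile(cls):
    """Toy class decorator: compile ``forward`` lazily on first call.

    NOT vllm.compilation.decorators.support_torch_compile; it only
    illustrates the "annotate the whole model class" pattern.
    """
    original_forward = cls.forward

    def forward(self, *args, **kwargs):
        if not hasattr(self, "_compiled_forward"):
            # Compile once and cache; dynamic=True tolerates varying
            # sequence lengths without retracing for every new shape.
            self._compiled_forward = torch.compile(original_forward,
                                                   dynamic=True)
        return self._compiled_forward(self, *args, **kwargs)

    cls.forward = forward
    return cls


@toy_support_torch_compile
class ToyModel(nn.Module):
    """Hypothetical stand-in for an inner ``*Model`` class."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(hidden_states))


if __name__ == "__main__":
    model = ToyModel()
    out = model(torch.randn(4, 32))  # first call pays the compilation cost
    print(out.shape)                 # torch.Size([4, 32])

Later calls reuse the cached compiled forward; the point is simply that the annotation lives on the model class, so no per-layer code has to change.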
@@ -29,6 +29,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -348,6 +349,7 @@ class MiniCPMDecoderLayer(nn.Module):
         return hidden_states, None
 
 
+@support_torch_compile
 class MiniCPMModel(nn.Module):
 
     def __init__(
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -204,6 +205,7 @@ class MPTBlock(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class MPTModel(nn.Module):
 
     def __init__(
@@ -27,6 +27,7 @@ import torch
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -290,6 +291,7 @@ class NemotronDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
+@support_torch_compile
 class NemotronModel(nn.Module):
 
     def __init__(
@@ -28,6 +28,7 @@ from torch import nn
 from transformers import OlmoConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -221,6 +222,7 @@ class OlmoDecoderLayer(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class OlmoModel(nn.Module):
 
     def __init__(self,
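Since the annotation changes how these models may be compiled rather than what they compute, a quick sanity check after a change like this is to load one of the touched models and generate a few tokens with vLLM's offline API. A minimal sketch, assuming a machine with enough GPU memory for the chosen checkpoint; whether torch.compile is actually engaged depends on the compilation settings of the installed vLLM version, which this sketch does not configure:

from vllm import LLM, SamplingParams

# Any model touched by this commit works here; jais-13b-chat is also the
# checkpoint used in the pipeline-parallel test list above.
llm = LLM(model="inceptionai/jais-13b-chat")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)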