From e72ae80b06405ea92b703c8979f046d68e970c94 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 10 Jul 2024 06:03:16 -0700
Subject: [PATCH] [Bugfix] Support 2D input shape in MoE layer (#6287)

---
 vllm/model_executor/models/mixtral.py   | 5 +++--
 vllm/model_executor/models/qwen2_moe.py | 6 ++++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 7f5e3b96..e5bd58a9 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -88,12 +88,13 @@ class MixtralMoE(nn.Module):
                                 tp_size=tp_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(num_tokens, hidden_size)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index ccaa6f20..7b18b5e0 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                                  bias=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
         if self.shared_expert is not None:
@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(orig_shape)
 
 
 class Qwen2MoeAttention(nn.Module):
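
Reviewer note (not part of the patch): a minimal, self-contained sketch of the fix, with the MoE gate/expert computation replaced by a hypothetical doubling stand-in. The old tuple unpacking `num_tokens, hidden_size = hidden_states.shape` assumes a 2D tensor and raises on 1D input, while saving `orig_shape` and restoring it with `view(orig_shape)` handles both shapes, since `view(-1, hidden_size)` flattens either one to `(num_tokens, hidden_size)`.

import torch

HIDDEN_SIZE = 8  # arbitrary size chosen for the sketch

def forward_sketch(hidden_states: torch.Tensor) -> torch.Tensor:
    # Remember the caller's shape, as the patched forward() does.
    orig_shape = hidden_states.shape
    # Flatten to (num_tokens, HIDDEN_SIZE): a no-op for 2D input,
    # and adds the token dimension for 1D input.
    flat = hidden_states.view(-1, HIDDEN_SIZE)
    out = flat * 2.0  # hypothetical stand-in for gate + experts
    # Restore whatever shape the caller passed in.
    return out.view(orig_shape)

x2d = torch.randn(4, HIDDEN_SIZE)  # (num_tokens, hidden_size)
x1d = torch.randn(HIDDEN_SIZE)     # single token, no token dimension
assert forward_sketch(x2d).shape == x2d.shape
assert forward_sketch(x1d).shape == x1d.shape
# The pre-patch unpacking would fail on the 1D case:
#   num_tokens, hidden_size = x1d.shape  # ValueError: not enough values to unpack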