diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 7f5e3b96..e5bd58a9 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -88,12 +88,13 @@ class MixtralMoE(nn.Module):
                                 tp_size=tp_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(num_tokens, hidden_size)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index ccaa6f20..7b18b5e0 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                                          bias=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
         if self.shared_expert is not None:
@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(orig_shape)
 
 
 class Qwen2MoeAttention(nn.Module):
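
To illustrate the change outside the patch: a minimal, self-contained sketch (the moe_like_forward helper and HIDDEN_SIZE constant below are hypothetical, not vLLM code) showing why capturing orig_shape up front and calling view(orig_shape) at the end handles both 1D and 2D hidden_states, whereas the old two-value unpacking only accepted 2D input.

import torch

HIDDEN_SIZE = 8  # stand-in for self.hidden_size / hidden_dim

def moe_like_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    # NOTE: hidden_states can have either 1D or 2D shape.
    orig_shape = hidden_states.shape
    hidden_states = hidden_states.view(-1, HIDDEN_SIZE)
    # Identity stand-in for gate + experts; only the shape handling matters here.
    final_hidden_states = hidden_states
    return final_hidden_states.view(orig_shape)

# 2D input: (num_tokens, hidden_size) round-trips unchanged.
assert moe_like_forward(torch.randn(4, HIDDEN_SIZE)).shape == (4, HIDDEN_SIZE)
# 1D input: a single token's flat hidden vector also round-trips; the old
# `num_tokens, hidden_size = hidden_states.shape` unpacking would raise here.
assert moe_like_forward(torch.randn(HIDDEN_SIZE)).shape == (HIDDEN_SIZE,)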