[Bugfix] Support 2D input shape in MoE layer (#6287)

Author: Woosuk Kwon
Date:   2024-07-10 06:03:16 -07:00 (committed by GitHub)
parent 8a924d2248
commit e72ae80b06
2 changed files with 7 additions and 4 deletions


@@ -88,12 +88,13 @@ class MixtralMoE(nn.Module):
                                 tp_size=tp_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(num_tokens, hidden_size)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):


@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                                 bias=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
         if self.shared_expert is not None:
@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(orig_shape)
 
 
 class Qwen2MoeAttention(nn.Module):
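
Both hunks apply the same pattern: record the shape the caller passed in, flatten to a 2D (num_tokens, hidden_size) view for routing and the expert computation, then restore the original shape on the way out. Below is a minimal standalone sketch of that pattern, not the vLLM code itself; moe_forward_sketch is a hypothetical helper and the routing/expert step is replaced by an identity stand-in.

import torch

def moe_forward_sketch(hidden_states: torch.Tensor, hidden_size: int) -> torch.Tensor:
    # Remember the caller's shape, which may be 1D or 2D.
    orig_shape = hidden_states.shape
    # Flatten any leading dimensions into a single token dimension.
    hidden_states = hidden_states.view(-1, hidden_size)
    # The router and expert MLPs would run on this 2D view; identity stand-in here.
    final_hidden_states = hidden_states
    # Restore whatever shape the caller passed in.
    return final_hidden_states.view(orig_shape)

# The output shape always matches the input shape:
assert moe_forward_sketch(torch.randn(4, 8), hidden_size=8).shape == (4, 8)  # 2D input
assert moe_forward_sketch(torch.randn(8), hidden_size=8).shape == (8,)       # 1D input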