[Bugfix] Support 2D input shape in MoE layer (#6287)
This commit is contained in:
parent
8a924d2248
commit
e72ae80b06
@ -88,12 +88,13 @@ class MixtralMoE(nn.Module):
|
||||
tp_size=tp_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||
return final_hidden_states.view(num_tokens, hidden_size)
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
|
||||
|
||||
@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_dim = hidden_states.shape[-1]
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
shared_output = None
|
||||
if self.shared_expert is not None:
|
||||
@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
class Qwen2MoeAttention(nn.Module):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user