[Bugfix] Support 2D input shape in MoE layer (#6287)
parent 8a924d2248
commit e72ae80b06
@@ -88,12 +88,13 @@ class MixtralMoE(nn.Module):
                                 tp_size=tp_size)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(num_tokens, hidden_size)
+        return final_hidden_states.view(orig_shape)


 class MixtralAttention(nn.Module):
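The reshape pattern introduced here can be exercised in isolation. Below is a minimal sketch, not vLLM's actual MixtralMoE (the ToyMoE class, its single-linear stand-in for the experts, and hidden_size=8 are assumptions for illustration): view(-1, self.hidden_size) flattens either input shape before routing, and view(orig_shape) restores whatever shape the caller passed in.

```python
# Minimal sketch of the flatten-and-restore pattern above; the module and
# sizes are illustrative assumptions, not code from this commit.
import torch
import torch.nn as nn


class ToyMoE(nn.Module):

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # Stand-in for the gate + fused experts; a single linear layer here.
        self.experts = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        final_hidden_states = self.experts(hidden_states)
        # Restore the caller's original shape.
        return final_hidden_states.view(orig_shape)


moe = ToyMoE(hidden_size=8)
assert moe(torch.randn(8)).shape == (8,)        # 1D input round-trips
assert moe(torch.randn(4, 8)).shape == (4, 8)   # 2D input round-trips
```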
@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                                                   bias=False)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
         if self.shared_expert is not None:
@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(orig_shape)


 class Qwen2MoeAttention(nn.Module):
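As a hypothetical before/after comparison (the old_style and new_style helpers below are illustrative, not code from this commit): strict two-way unpacking of hidden_states.shape only accepts a 2D tensor, whereas remembering orig_shape and flattening via view(-1, hidden_dim) also handles the 1D case mentioned in the NOTE comments.

```python
import torch


def old_style(hidden_states: torch.Tensor) -> torch.Tensor:
    # Pre-fix pattern: unconditionally unpacks exactly two dimensions.
    num_tokens, hidden_dim = hidden_states.shape
    flat = hidden_states.view(-1, hidden_dim)
    return flat.view(num_tokens, hidden_dim)


def new_style(hidden_states: torch.Tensor) -> torch.Tensor:
    # Post-fix pattern: remember the original shape, flatten, restore.
    orig_shape = hidden_states.shape
    hidden_dim = hidden_states.shape[-1]
    flat = hidden_states.view(-1, hidden_dim)
    return flat.view(orig_shape)


x1d = torch.randn(16)
x2d = torch.randn(4, 16)

assert new_style(x1d).shape == x1d.shape   # 1D input now round-trips
assert new_style(x2d).shape == x2d.shape   # 2D input still round-trips

try:
    old_style(x1d)                         # a 1-element shape cannot unpack
except ValueError:
    pass
```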