[Model] Add starcoder2 awq support (#3569)

This commit is contained in:
少年 2024-03-25 12:07:36 +08:00 committed by GitHub
parent 56a8652f33
commit b0dfa91dd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -141,8 +141,9 @@ class Starcoder2MLP(nn.Module):
bias=config.use_bias,
linear_method=linear_method,
)
self.act = get_act_fn(config.hidden_act,
intermediate_size=config.intermediate_size)
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)