[Model] Add support for IBM Granite Code models (#4636)
This commit is contained in:
parent
e254497b66
commit
6eaccb7353
@ -58,15 +58,16 @@ class LlamaMLP(nn.Module):
|
|||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
quant_config: Optional[QKVParallelLinear] = None,
|
quant_config: Optional[QKVParallelLinear] = None,
|
||||||
|
bias: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
hidden_size, [intermediate_size] * 2,
|
hidden_size, [intermediate_size] * 2,
|
||||||
bias=False,
|
bias=bias,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
self.down_proj = RowParallelLinear(intermediate_size,
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=bias,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
@ -209,6 +210,7 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
bias=getattr(config, "mlp_bias", False),
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
@ -348,6 +350,8 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
# compatibility
|
# compatibility
|
||||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
)
|
)
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head.weight = self.model.embed_tokens.weight
|
||||||
|
|
||||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user