[Misc] Support quantization of MllamaForCausalLM (#8822)
parent e2c6e0a829
commit 7193774b1f
@@ -624,6 +624,7 @@ class MllamaTextCrossAttention(nn.Module):
         self,
         config: Optional[config_mllama.MllamaTextConfig] = None,
         layer_idx: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -648,12 +649,14 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_heads,
             self.num_key_value_heads,
             bias=False,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
         )
         # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
         # use huggingface's instead
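
The two `quant_config=quant_config` additions above follow the usual vLLM pattern: a parallel linear layer picks a quantized weight layout and matmul method when a config is present, and falls back to plain unquantized weights when it is None. Below is a minimal sketch of that pattern; `ToyCrossAttention` and its sizes are illustrative, not code from this commit, and instantiating vLLM's parallel linear layers assumes the engine's distributed state has already been initialized, as it is during normal model loading.

# Sketch only: forwarding an Optional[QuantizationConfig] into vLLM's
# parallel linear layers, mirroring the diff above. Class name and sizes
# are hypothetical.
from typing import Optional

import torch.nn as nn

from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class ToyCrossAttention(nn.Module):

    def __init__(self,
                 hidden_size: int = 4096,
                 head_dim: int = 128,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        # With quant_config=None both projections keep fp16/bf16 weights;
        # with a config (e.g. FP8, GPTQ, AWQ) they create the matching
        # quantized weights and kernels instead.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )
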
@@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
     """Cross-attention transformer block with tanh-gated attention
     and feedforward."""
 
-    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
+    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
+                 quant_config: Optional[QuantizationConfig]) \
         -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.cross_attn = MllamaTextCrossAttention(
             config=config,
             layer_idx=layer_idx,
+            quant_config=quant_config,
         )
 
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -725,6 +730,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
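
The two hunks above thread the same `quant_config` object from the decoder layer into both its cross-attention and MLP submodules, while the RMSNorm layers stay in full precision. The sketch below shows the general shape of that pass-through using hypothetical toy classes, not the commit's actual code.

# Sketch only (hypothetical classes): one Optional[QuantizationConfig] handed
# to a decoder layer is forwarded unchanged to every submodule that owns
# linear projections; normalization layers never receive it.
from typing import Optional

import torch.nn as nn

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class ToyMLP(nn.Module):

    def __init__(self, hidden_size: int, intermediate_size: int,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        # In the real model this would build parallel linear layers with
        # quant_config, analogous to LlamaMLP in the diff.
        self.quant_config = quant_config


class ToyCrossAttentionDecoderLayer(nn.Module):

    def __init__(self, hidden_size: int, intermediate_size: int,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        # The same config object is reused; there is no per-submodule copy.
        self.mlp = ToyMLP(hidden_size, intermediate_size,
                          quant_config=quant_config)
        # Norm layers are left unquantized, as in the diff.
        self.input_layernorm = nn.LayerNorm(hidden_size)
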
@@ -780,7 +786,8 @@ class MllamaTextModel(nn.Module):
         for layer_idx in range(config.num_hidden_layers):
             if layer_idx in self.cross_attention_layers:
                 layers.append(
-                    MllamaCrossAttentionDecoderLayer(config, layer_idx))
+                    MllamaCrossAttentionDecoderLayer(
+                        config, layer_idx, quant_config=quant_config))
             else:
                 # TODO: force LlamaDecoderLayer to config.attention_bias=False
                 layers.append(
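
With `quant_config` now reaching every linear projection in the text model, a quantized Mllama checkpoint can be loaded through the normal vLLM entry point. The snippet below is a usage sketch, not part of the commit; the checkpoint name is a placeholder, and the quantization method is normally detected from the checkpoint config rather than passed explicitly.

# Usage sketch (placeholder checkpoint name): serving a quantized Mllama
# checkpoint once quant_config is honored by the text model.
from vllm import LLM, SamplingParams

llm = LLM(
    model="<org>/Llama-3.2-11B-Vision-Instruct-FP8",  # placeholder repo id
    max_model_len=4096,
    # Usually unnecessary: vLLM reads the method from the checkpoint's
    # config; pass it explicitly only to override detection.
    quantization="fp8",
)

out = llm.generate("Summarize what weight quantization trades off.",
                   SamplingParams(temperature=0.0, max_tokens=64))
print(out[0].outputs[0].text)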