[Misc] Support quantization of MllamaForCausalLM (#8822)
This commit is contained in:
parent
e2c6e0a829
commit
7193774b1f
@ -624,6 +624,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
self,
|
||||
config: Optional[config_mllama.MllamaTextConfig] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -648,12 +649,14 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
self.num_heads,
|
||||
self.num_key_value_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
|
||||
# use huggingface's instead
|
||||
@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
"""Cross-attention transformer block with tanh-gated attention
|
||||
and feedforward."""
|
||||
|
||||
def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
|
||||
def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig]) \
|
||||
-> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.cross_attn = MllamaTextCrossAttention(
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -725,6 +730,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -780,7 +786,8 @@ class MllamaTextModel(nn.Module):
|
||||
for layer_idx in range(config.num_hidden_layers):
|
||||
if layer_idx in self.cross_attention_layers:
|
||||
layers.append(
|
||||
MllamaCrossAttentionDecoderLayer(config, layer_idx))
|
||||
MllamaCrossAttentionDecoderLayer(
|
||||
config, layer_idx, quant_config=quant_config))
|
||||
else:
|
||||
# TODO: force LlamaDecoderLayer to config.attention_bias=False
|
||||
layers.append(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user