[Misc] Support quantization of MllamaForCausalLM (#8822)

This commit is contained in:
Michael Goin 2024-09-25 17:46:22 -04:00 committed by GitHub
parent e2c6e0a829
commit 7193774b1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -624,6 +624,7 @@ class MllamaTextCrossAttention(nn.Module):
self,
config: Optional[config_mllama.MllamaTextConfig] = None,
layer_idx: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@ -648,12 +649,14 @@ class MllamaTextCrossAttention(nn.Module):
self.num_heads,
self.num_key_value_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
)
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# use huggingface's instead
@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention
and feedforward."""
def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
quant_config: Optional[QuantizationConfig]) \
-> None:
super().__init__()
self.layer_idx = layer_idx
self.cross_attn = MllamaTextCrossAttention(
config=config,
layer_idx=layer_idx,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
@ -725,6 +730,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -780,7 +786,8 @@ class MllamaTextModel(nn.Module):
for layer_idx in range(config.num_hidden_layers):
if layer_idx in self.cross_attention_layers:
layers.append(
MllamaCrossAttentionDecoderLayer(config, layer_idx))
MllamaCrossAttentionDecoderLayer(
config, layer_idx, quant_config=quant_config))
else:
# TODO: force LlamaDecoderLayer to config.attention_bias=False
layers.append(