[Misc] Support quantization of MllamaForCausalLM (#8822)

Author: Michael Goin
Date: 2024-09-25 17:46:22 -04:00 (committed via GitHub)
Parent: e2c6e0a829
Commit: 7193774b1f


@@ -624,6 +624,7 @@ class MllamaTextCrossAttention(nn.Module):
         self,
         config: Optional[config_mllama.MllamaTextConfig] = None,
         layer_idx: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -648,12 +649,14 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_heads,
             self.num_key_value_heads,
             bias=False,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
         )
         # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
         # use huggingface's instead
@@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
     """Cross-attention transformer block with tanh-gated attention
     and feedforward."""
 
-    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
+    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
+                 quant_config: Optional[QuantizationConfig]) \
         -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.cross_attn = MllamaTextCrossAttention(
             config=config,
             layer_idx=layer_idx,
+            quant_config=quant_config,
         )
 
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -725,6 +730,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
@@ -780,7 +786,8 @@ class MllamaTextModel(nn.Module):
         for layer_idx in range(config.num_hidden_layers):
             if layer_idx in self.cross_attention_layers:
                 layers.append(
-                    MllamaCrossAttentionDecoderLayer(config, layer_idx))
+                    MllamaCrossAttentionDecoderLayer(
+                        config, layer_idx, quant_config=quant_config))
             else:
                 # TODO: force LlamaDecoderLayer to config.attention_bias=False
                 layers.append(
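
The change threads the quant_config that MllamaTextModel already receives down into the cross-attention decoder layers, so the QKV and output projections of those layers can be loaded from a quantized checkpoint like the rest of the text model. For reference, a minimal usage sketch is below; the checkpoint name is an illustrative assumption and not part of this commit — any Mllama checkpoint quantized with a scheme vLLM supports should load the same way.

# Hypothetical usage sketch: the model name below is an assumed quantized
# Mllama checkpoint, used only for illustration.
from vllm import LLM, SamplingParams

llm = LLM(
    model="neuralmagic/Llama-3.2-11B-Vision-Instruct-FP8-dynamic",  # assumed checkpoint
    max_model_len=4096,
)
outputs = llm.generate(
    ["Summarize in one sentence what weight quantization does."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)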