[Bugfix] Fixup Mamba (#10004)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith 2024-11-04 22:46:38 -05:00 committed by GitHub
parent bbc3619dc8
commit ad23318928

@@ -39,8 +39,8 @@ class MambaDecoderLayer(nn.Module):
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
-        mixer_rms_rps = config.mixer_rms_rps if self.is_falcon_mamba else None
-        self.mamba = MambaMixer(hidden_size=config.hidden_size,
+        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
+        self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
                                 conv_kernel_size=config.conv_kernel,
                                 intermediate_size=config.intermediate_size,
@@ -48,7 +48,7 @@ class MambaDecoderLayer(nn.Module):
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.use_bias,
                                 use_rms_norm=self.is_falcon_mamba,
-                                rms_norm_eps=mixer_rms_rps,
+                                rms_norm_eps=mixer_rms_eps,
                                 activation=config.hidden_act)
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
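The two hunks above repair FalconMamba support in this layer: the Hugging Face FalconMamba config exposes mixer_rms_eps, so the old read of a non-existent mixer_rms_rps attribute fails as soon as a falcon_mamba model is constructed, and the mixer attribute is named mixer again so parameter names stay consistent with the mixer-prefixed names used for this model (for example during weight loading). Below is a minimal sketch of how the corrected eps lookup behaves for both config types; the MambaConfig/FalconMambaConfig classes from transformers are assumptions of this example and are not part of the diff itself.

    # Sketch only: the eps is read for falcon_mamba configs and left as None otherwise.
    # Assumes a transformers release with FalconMamba support (not part of this diff).
    from transformers import FalconMambaConfig, MambaConfig

    for config in (MambaConfig(), FalconMambaConfig()):
        is_falcon_mamba = config.model_type == "falcon_mamba"
        mixer_rms_eps = config.mixer_rms_eps if is_falcon_mamba else None
        print(config.model_type, mixer_rms_eps)
    # Prints something like "mamba None" and "falcon_mamba 1e-06"
    # (the exact eps value depends on the config).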
@@ -99,7 +99,6 @@ class MambaModel(nn.Module):
         for i in range(config.num_hidden_layers):
             decoder_layers.append(
                 MambaDecoderLayer(config,
-                                  layer_idx=i,
                                   cache_config=cache_config,
                                   quant_config=quant_config))
         self.layers = nn.ModuleList(decoder_layers)
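The last hunk drops the layer_idx keyword from the MambaDecoderLayer call, presumably because the layer's __init__ (see the first hunk) does not declare that parameter, so keeping it would make model construction fail with a TypeError. A toy illustration of that failure mode, using a hypothetical Layer class rather than the real vLLM code:

    # Hypothetical example, not vLLM code: a constructor without a layer_idx
    # parameter rejects the extra keyword, which is the failure this hunk avoids.
    class Layer:
        def __init__(self, config, cache_config=None, quant_config=None):
            self.config = config

    Layer(config={}, cache_config=None, quant_config=None)   # fine after the fix
    try:
        Layer(config={}, layer_idx=0)                         # the pre-fix call shape
    except TypeError as e:
        print(e)  # unexpected keyword argument 'layer_idx'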