[Bugfix] Fixup Mamba (#10004)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent bbc3619dc8
commit ad23318928
@@ -39,8 +39,8 @@ class MambaDecoderLayer(nn.Module):
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
-        mixer_rms_rps = config.mixer_rms_rps if self.is_falcon_mamba else None
-        self.mamba = MambaMixer(hidden_size=config.hidden_size,
+        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
+        self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
                                 conv_kernel_size=config.conv_kernel,
                                 intermediate_size=config.intermediate_size,
@@ -48,7 +48,7 @@ class MambaDecoderLayer(nn.Module):
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.use_bias,
                                 use_rms_norm=self.is_falcon_mamba,
-                                rms_norm_eps=mixer_rms_rps,
+                                rms_norm_eps=mixer_rms_eps,
                                 activation=config.hidden_act)

         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -99,7 +99,6 @@ class MambaModel(nn.Module):
         for i in range(config.num_hidden_layers):
             decoder_layers.append(
                 MambaDecoderLayer(config,
-                                  layer_idx=i,
                                   cache_config=cache_config,
                                   quant_config=quant_config))
         self.layers = nn.ModuleList(decoder_layers)
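Why the rename matters (a reading of the diff, not something the commit states): storing the block under the attribute name `mixer` makes its parameters appear as `mixer.*` in the module tree, which is presumably what the layer's forward pass and checkpoint weight names expect; likewise, FalconMamba's config field is assumed to be `mixer_rms_eps`, with `mixer_rms_rps` being a typo. A minimal sketch of the attribute-name point, using hypothetical DummyMixer / DummyDecoderLayer stand-ins rather than vLLM's actual classes:

# Illustration only: hypothetical stand-ins, not vLLM's MambaMixer / MambaDecoderLayer.
import torch
import torch.nn as nn


class DummyMixer(nn.Module):
    """Minimal stand-in for a mixer block."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.in_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.in_proj(x)


class DummyDecoderLayer(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        # Keeping the block under the attribute name `mixer` means its
        # parameters show up as "mixer.*" in named_parameters(); under
        # `self.mamba` the names would not line up with mixer.* keys.
        self.mixer = DummyMixer(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mixer(x)


layer = DummyDecoderLayer()
print([name for name, _ in layer.named_parameters()])
# -> ['mixer.in_proj.weight', 'mixer.in_proj.bias']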