[Bugfix] Fixup Mamba (#10004)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
commit ad23318928
parent bbc3619dc8
@@ -39,8 +39,8 @@ class MambaDecoderLayer(nn.Module):
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
-        mixer_rms_rps = config.mixer_rms_rps if self.is_falcon_mamba else None
-        self.mamba = MambaMixer(hidden_size=config.hidden_size,
+        mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
+        self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
                                 conv_kernel_size=config.conv_kernel,
                                 intermediate_size=config.intermediate_size,
@@ -48,7 +48,7 @@ class MambaDecoderLayer(nn.Module):
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.use_bias,
                                 use_rms_norm=self.is_falcon_mamba,
-                                rms_norm_eps=mixer_rms_rps,
+                                rms_norm_eps=mixer_rms_eps,
                                 activation=config.hidden_act)

         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -99,7 +99,6 @@ class MambaModel(nn.Module):
         for i in range(config.num_hidden_layers):
             decoder_layers.append(
                 MambaDecoderLayer(config,
-                                  layer_idx=i,
                                   cache_config=cache_config,
                                   quant_config=quant_config))
         self.layers = nn.ModuleList(decoder_layers)
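
For context, the rename from self.mamba to self.mixer matters because name-based weight loading matches module attribute paths against checkpoint parameter names, and Hugging Face Mamba checkpoints keep each block's parameters under a "mixer." prefix. The snippet below is a minimal sketch of that matching, not vLLM's actual loader; the TinyDecoderLayer stand-in and the checkpoint key are assumptions for illustration only.

import torch
import torch.nn as nn

class TinyDecoderLayer(nn.Module):
    """Stand-in for MambaDecoderLayer: only the attribute name matters here."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Registered as "mixer.weight"; with the old `self.mamba` attribute the
        # parameter would be "mamba.weight" and the lookup below would fail.
        self.mixer = nn.Linear(hidden_size, hidden_size, bias=False)

layer = TinyDecoderLayer(hidden_size=4)
params = dict(layer.named_parameters())

# Hypothetical checkpoint entry using the "mixer." prefix, as HF Mamba checkpoints do.
ckpt = {"mixer.weight": torch.zeros(4, 4)}
for name, tensor in ckpt.items():
    params[name].data.copy_(tensor)  # matches only because the attribute is `mixer`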