fix DbrxFusedNormAttention missing cache_config (#5340)

Co-authored-by: team <calvinn.ng@ahrefs.com>
Author: Calvinn Ng (committed 2024-06-08 05:10:21 +08:00 via GitHub)
Parent: 6840a71610
Commit: 767c727a81

@@ -247,11 +247,12 @@ class DbrxFusedNormAttention(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
-        self.attn = DbrxAttention(config, quant_config)
+        self.attn = DbrxAttention(config, cache_config, quant_config)
         self.norm_1 = nn.LayerNorm(self.d_model)
         self.norm_2 = nn.LayerNorm(self.d_model)
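
For context on where the forwarded `cache_config` ends up, below is a minimal sketch of the callee side, following the pattern vLLM model files use of threading `cache_config` down into the shared `Attention` backend. The import paths, attribute names, and the exact `Attention` signature are assumptions for illustration, not the verbatim DBRX source.

```python
# Sketch only: how DbrxAttention is expected to consume the forwarded
# cache_config. Import paths and the Attention signature are assumptions
# based on vLLM's conventions around this commit, not the verbatim source.
from typing import Optional

import torch.nn as nn

from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class DbrxAttention(nn.Module):

    def __init__(
        self,
        config,  # DbrxConfig; left untyped to keep the sketch self-contained
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.scaling = self.head_dim**-0.5
        # The backend uses cache_config (block size, KV-cache dtype, and
        # similar engine settings) to configure paged attention; before
        # this fix it silently fell back to the default of None.
        self.attn = Attention(
            self.total_num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
        )
```

The practical effect of the one-line fix above is that engine-level KV-cache settings now reach the DBRX attention backend instead of silently defaulting to `None` at the `DbrxFusedNormAttention` boundary.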