Support for Stable LM 2 (#2598)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
6b7de1a030
commit
3a0e1fc070
@@ -98,7 +98,7 @@ class StablelmAttention(nn.Module):
|
|||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_key_value_heads * self.head_dim
|
self.kv_size = self.num_key_value_heads * self.head_dim
|
||||||
|
self.qkv_bias = getattr(config, "use_qkv_bias", False)
|
||||||
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
|
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||||
@@ -108,7 +108,7 @@ class StablelmAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_key_value_heads,
|
self.total_num_key_value_heads,
|
||||||
bias=False,
|
self.qkv_bias,
|
||||||
linear_method=linear_method)
|
linear_method=linear_method)
|
||||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user