Fix Baichuan2-7B-Chat (#1987)
parent 6ccc0bfffb
commit 2b981012a6
@@ -366,12 +366,16 @@ class BaiChuanBaseForCausalLM(nn.Module):
             weight_loader(param, loaded_weight)
 
 
-class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM
+                          ):  # baichuan 13b, baichuan2 13b, baichuan2 7b
 
     def __init__(self,
                  config,
                  linear_method: Optional[LinearMethodBase] = None):
-        super().__init__(config, "ALIBI", linear_method)
+        if config.hidden_size == 4096:  # baichuan2 7b
+            super().__init__(config, "ROPE", linear_method)
+        else:  # baichuan 13b, baichuan2 13b
+            super().__init__(config, "ALIBI", linear_method)
 
 
 class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
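For context, a minimal standalone sketch of the dispatch this hunk implements: Baichuan2-7B is detected by its 4096 hidden size and switched to rotary embeddings, while every other size keeps ALiBi. The names `ModelConfig` and `select_position_embedding` are illustrative only and not part of the vLLM API; only the hidden_size rule and the "ROPE"/"ALIBI" strings come from the diff.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # Illustrative stand-in for the HuggingFace config object used in the diff.
    hidden_size: int


def select_position_embedding(config: ModelConfig) -> str:
    """Return the positional-embedding mode implied by the model size."""
    if config.hidden_size == 4096:
        # Baichuan2-7B uses rotary position embeddings.
        return "ROPE"
    # Baichuan-13B and Baichuan2-13B fall back to ALiBi.
    return "ALIBI"


if __name__ == "__main__":
    assert select_position_embedding(ModelConfig(hidden_size=4096)) == "ROPE"
    assert select_position_embedding(ModelConfig(hidden_size=5120)) == "ALIBI"
```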