Fix Baichuan2-7B-Chat (#1987)

This commit is contained in:
firebook 2023-12-09 01:38:36 +08:00 committed by GitHub
parent 6ccc0bfffb
commit 2b981012a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -366,12 +366,16 @@ class BaiChuanBaseForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
class BaichuanForCausalLM(BaiChuanBaseForCausalLM
): # baichuan 13b, baichuan2 13b, baichuan2 7b
def __init__(self,
config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__(config, "ALIBI", linear_method)
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", linear_method)
else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", linear_method)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b