From 2b981012a6eb27d566f03cf61c06b1ef7a522f27 Mon Sep 17 00:00:00 2001
From: firebook
Date: Sat, 9 Dec 2023 01:38:36 +0800
Subject: [PATCH] Fix Baichuan2-7B-Chat (#1987)

---
 vllm/model_executor/models/baichuan.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 3b56b9e1..f7a3b90a 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -366,12 +366,16 @@ class BaiChuanBaseForCausalLM(nn.Module):
             weight_loader(param, loaded_weight)
 
 
-class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM
+                          ):  # baichuan 13b, baichuan2 13b, baichuan2 7b
 
     def __init__(self,
                  config,
                  linear_method: Optional[LinearMethodBase] = None):
-        super().__init__(config, "ALIBI", linear_method)
+        if config.hidden_size == 4096:  # baichuan2 7b
+            super().__init__(config, "ROPE", linear_method)
+        else:  # baichuan 13b, baichuan2 13b
+            super().__init__(config, "ALIBI", linear_method)
 
 
 class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b