From e2afb03c92a06700d296a2e7f6565d4a4f05168c Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Fri, 14 Jun 2024 22:28:11 +0200
Subject: [PATCH] [Bugfix] Enable loading FP8 checkpoints for gpt_bigcode
 models (#5460)

Signed-off-by: Thomas Parnell
---
 vllm/model_executor/models/gpt_bigcode.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 69b75763..b15ed119 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -299,4 +299,10 @@ class GPTBigCodeForCausalLM(nn.Module):
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
-            weight_loader(param, loaded_weight)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, 'q')
+                weight_loader(param, loaded_weight, 'k')
+                weight_loader(param, loaded_weight, 'v')
+            else:
+                weight_loader(param, loaded_weight)
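
Note on why the scales are loaded three times (not part of the patch itself):
GPTBigCode fuses the query, key and value projections into a single c_attn
layer, which vLLM implements as a QKV-fused parallel linear layer whose
weight_loader accepts a shard id ('q', 'k' or 'v') to place each projection's
slice of the checkpoint. An FP8 checkpoint stores one per-tensor input_scale
and weight_scale for the fused c_attn, so that single scale must be loaded
into all three shards. The sketch below illustrates the pattern with a
hypothetical qkv_scale_loader and a simplified one-slot-per-shard parameter;
it is a minimal stand-in under those assumptions, not vLLM's actual
QKVParallelLinear.weight_loader.

    import torch

    # Hypothetical stand-in for the fused c_attn scale parameter: one
    # scale slot per projection. vLLM's real parameter layout differs;
    # this only illustrates the per-shard loading pattern.
    SHARD_INDEX = {"q": 0, "k": 1, "v": 2}

    def qkv_scale_loader(param: torch.Tensor,
                         loaded_weight: torch.Tensor,
                         shard_id: str) -> None:
        # Copy the checkpoint's per-tensor scale into the slot for the
        # requested projection.
        param[SHARD_INDEX[shard_id]] = loaded_weight

    param = torch.zeros(3)
    checkpoint_scale = torch.tensor(0.042)  # per-tensor FP8 scale from checkpoint

    # Mirrors the patched loop: the same scale is loaded once per shard.
    for shard_id in ("q", "k", "v"):
        qkv_scale_loader(param, checkpoint_scale, shard_id)

    print(param)  # tensor([0.0420, 0.0420, 0.0420])

Before this patch, the unconditional weight_loader(param, loaded_weight) call
passed no shard id for the c_attn scale tensors, which is why FP8 gpt_bigcode
checkpoints failed to load; the TODO in the diff notes that this special case
should eventually move into the fp8 linear method itself.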