diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index d0427fb9..2f6cdbc6 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -579,6 +579,10 @@ class ShardedStateLoader(BaseModelLoader): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, cache_config) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) rank = get_tensor_model_parallel_rank() pattern = os.path.join( local_model_path,