diff --git a/vllm/config.py b/vllm/config.py index a2739e5f..6bafa73c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -119,6 +119,16 @@ class ModelConfig: # Force ROCm to load from pt weights if nothing specific is set if load_format == "auto": load_format = "pt" + + # FIXME(woosuk): This is a temporary hack. Support safetensor weights. + architectures = getattr(self.hf_config, "architectures", []) + if "MixtralForCausalLM" in architectures and load_format != "pt": + logger.info( + "Currently, only 'pt' format is supported for Mixtral. " + "Changing the format to 'pt'. This may re-download the " + "weights if you have downloaded the safetensor weights.") + load_format = "pt" + self.load_format = load_format def _verify_tokenizer_mode(self) -> None: