Change the load format to pt for Mixtral (#2028)

This commit is contained in:
Woosuk Kwon 2023-12-11 10:32:17 -08:00 committed by GitHub
parent 4ff0203987
commit b9bcdc7158
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -119,6 +119,16 @@ class ModelConfig:
# Force ROCm to load from pt weights if nothing specific is set
if load_format == "auto":
load_format = "pt"
# FIXME(woosuk): This is a temporary hack. Support safetensor weights.
architectures = getattr(self.hf_config, "architectures", [])
if "MixtralForCausalLM" in architectures and load_format != "pt":
logger.info(
"Currently, only 'pt' format is supported for Mixtral. "
"Changing the format to 'pt'. This may re-download the "
"weights if you have downloaded the safetensor weights.")
load_format = "pt"
self.load_format = load_format
def _verify_tokenizer_mode(self) -> None: