Change the load format to pt for Mixtral (#2028)
This commit is contained in:
parent
4ff0203987
commit
b9bcdc7158
@ -119,6 +119,16 @@ class ModelConfig:
|
|||||||
# Force ROCm to load from pt weights if nothing specific is set
|
# Force ROCm to load from pt weights if nothing specific is set
|
||||||
if load_format == "auto":
|
if load_format == "auto":
|
||||||
load_format = "pt"
|
load_format = "pt"
|
||||||
|
|
||||||
|
# FIXME(woosuk): This is a temporary hack. Support safetensor weights.
|
||||||
|
architectures = getattr(self.hf_config, "architectures", [])
|
||||||
|
if "MixtralForCausalLM" in architectures and load_format != "pt":
|
||||||
|
logger.info(
|
||||||
|
"Currently, only 'pt' format is supported for Mixtral. "
|
||||||
|
"Changing the format to 'pt'. This may re-download the "
|
||||||
|
"weights if you have downloaded the safetensor weights.")
|
||||||
|
load_format = "pt"
|
||||||
|
|
||||||
self.load_format = load_format
|
self.load_format = load_format
|
||||||
|
|
||||||
def _verify_tokenizer_mode(self) -> None:
|
def _verify_tokenizer_mode(self) -> None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user