[Fix] Fix quantization="gptq" when using Marlin (#3319)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit b167109ba1
parent 602358f8a8
@@ -168,13 +168,18 @@ class ModelConfig:
         # Parse quantization method from the HF model config, if available.
         hf_quant_config = getattr(self.hf_config, "quantization_config", None)
         if hf_quant_config is not None:
             hf_quant_method = str(hf_quant_config["quant_method"]).lower()

             # If the GPTQ model is serialized in marlin format, use marlin.
             if (hf_quant_method == "gptq"
                     and "is_marlin_format" in hf_quant_config
                     and hf_quant_config["is_marlin_format"]):
+                logger.info("The model is serialized in Marlin format. "
+                            "Using Marlin kernel.")
                 hf_quant_method = "marlin"
+                if self.quantization == "gptq":
+                    self.quantization = hf_quant_method
+
             if self.quantization is None:
                 self.quantization = hf_quant_method
             elif self.quantization != hf_quant_method:
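For context, below is a self-contained sketch of the resolution logic after this commit. The function name resolve_quantization, the plain-dict config, and the error message are illustrative stand-ins rather than vLLM's actual API; only the branch structure mirrors the diff above. Previously, passing quantization="gptq" for a GPTQ checkpoint serialized in Marlin format fell through to the mismatch check ("gptq" != "marlin") and raised; the added branch upgrades the user's request to "marlin" first.

from typing import Optional


def resolve_quantization(requested: Optional[str],
                         hf_quant_config: Optional[dict]) -> Optional[str]:
    """Sketch of the fixed logic; not the real vLLM ModelConfig API."""
    if hf_quant_config is None:
        return requested
    hf_quant_method = str(hf_quant_config["quant_method"]).lower()
    # If the GPTQ checkpoint is serialized in Marlin format, use marlin.
    if (hf_quant_method == "gptq"
            and hf_quant_config.get("is_marlin_format")):
        hf_quant_method = "marlin"
        # The fix: an explicit "gptq" request now follows the checkpoint
        # format instead of tripping the mismatch check below.
        if requested == "gptq":
            requested = hf_quant_method
    if requested is None:
        requested = hf_quant_method
    elif requested != hf_quant_method:
        # Paraphrased error; the real message lives in vLLM's config.py.
        raise ValueError(
            f"quantization={requested!r} does not match the checkpoint's "
            f"quantization method {hf_quant_method!r}")
    return requested


marlin_cfg = {"quant_method": "gptq", "is_marlin_format": True}
assert resolve_quantization("gptq", marlin_cfg) == "marlin"  # fixed by this commit
assert resolve_quantization(None, marlin_cfg) == "marlin"    # auto-detect, unchanged
assert resolve_quantization("gptq", {"quant_method": "gptq"}) == "gptq"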