[Misc] Update to comply with the new compressed-tensors config (#5350)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Author: Dipika Sikka, 2024-06-09 23:49:46 -04:00 (committed by GitHub)
parent 45f92c00cf
commit 5884c2b454
4 changed files with 19 additions and 20 deletions

tests/quantization/test_compressed_tensors.py

@@ -5,15 +5,15 @@ Run `pytest tests/quantization/test_compressed_tensors.py`.
 
 import torch
 
+from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
     CompressedTensorsW8A8StaticTensor)
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed"
-    with vllm_runner(model_path, quantization="sparseml",
-                     enforce_eager=True) as llm:
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
 
         layer = model.model.layers[0]
@@ -40,11 +40,17 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
         assert qkv_proj.input_scale.dtype is torch.float32
 
 
+def test_compressed_tensors_no_enforce_eager(vllm_runner):
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    with vllm_runner(model_path) as llm:
+        sampling_params = SamplingParams()
+        output = llm.generate("Hello world!", sampling_params=sampling_params)
+        assert output
+
+
 def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
-    model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
-    with vllm_runner(model_path,
-                     quantization="sparseml",
-                     enforce_eager=True,
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
+    with vllm_runner(model_path, enforce_eager=True,
                      dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
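
For context, a minimal standalone sketch of what the updated tests now rely on: vLLM detects the compressed-tensors method from the checkpoint itself, so the tests drop the explicit quantization="sparseml" override. The model stub comes from the test above; the invocation below is an assumed, illustrative equivalent, not part of the diff.

    from vllm import LLM, SamplingParams

    # The quantization method is inferred from the checkpoint's config;
    # no explicit quantization= argument is needed anymore.
    llm = LLM(model="nm-testing/tinyllama-oneshot-w8a8-static-v2",
              enforce_eager=True)
    outputs = llm.generate("Hello world!", SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)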

vllm/config.py

@@ -164,12 +164,8 @@ class ModelConfig:
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
         if quant_cfg is None:
-            # SparseML uses a "compression_config" with a "quantization_config".
-            compression_cfg = getattr(self.hf_config, "compression_config",
-                                      None)
-            if compression_cfg is not None:
-                quant_cfg = compression_cfg.get("quantization_config", None)
-
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
         return quant_cfg
 
     def _verify_quantization(self) -> None:
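
To make the simplified lookup concrete, here is a self-contained sketch (assumed config shapes, not code from this commit): the old path unwrapped a "quantization_config" nested inside "compression_config", while the new compressed-tensors format treats the whole "compression_config" entry as the quantization config.

    from types import SimpleNamespace

    def parse_quant_cfg(hf_config):
        # Mirrors the shape of the new _parse_quant_hf_config above.
        quant_cfg = getattr(hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors checkpoints expose "compression_config" directly
            quant_cfg = getattr(hf_config, "compression_config", None)
        return quant_cfg

    # Hypothetical configs; the field contents are assumptions for illustration.
    gptq_style = SimpleNamespace(quantization_config={"quant_method": "gptq"})
    ct_style = SimpleNamespace(compression_config={"format": "int-quantized"})
    assert parse_quant_cfg(gptq_style)["quant_method"] == "gptq"
    assert parse_quant_cfg(ct_style)["format"] == "int-quantized"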

vllm/model_executor/layers/quantization/__init__.py

@@ -31,7 +31,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "gptq_marlin": GPTQMarlinConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
-    "sparseml": CompressedTensorsConfig,
+    "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
 }
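
One quick way to exercise the rename (assuming lookups go through the get_quantization_config helper defined alongside this dict): the registry now resolves "compressed-tensors", and the old "sparseml" key is gone.

    from vllm.model_executor.layers.quantization import get_quantization_config

    quant_cls = get_quantization_config("compressed-tensors")
    print(quant_cls)  # -> CompressedTensorsConfig

    try:
        get_quantization_config("sparseml")  # removed key
    except ValueError:
        print("'sparseml' is no longer a registered method")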

vllm/model_executor/model_loader/weight_utils.py

@@ -122,12 +122,9 @@ def get_quant_config(model_config: ModelConfig,
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)
     if hf_quant_config is None:
-        compression_config = getattr(model_config.hf_config,
-                                     "compression_config", None)
-        if compression_config is not None:
-            hf_quant_config = compression_config.get("quantization_config",
-                                                     None)
-
+        # compressed-tensors uses a "compression_config" key
+        hf_quant_config = getattr(model_config.hf_config, "compression_config",
+                                  None)
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
     # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
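
End to end, the user-visible change is just the method name. A hedged example of forcing the method explicitly (model stub reused from the tests above; omitting quantization= lets get_quant_config() pick up the checkpoint's "compression_config" instead):

    from vllm import LLM

    # Was quantization="sparseml" before this commit.
    llm = LLM(model="nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
              quantization="compressed-tensors",
              dtype="float16",
              enforce_eager=True)
    print(llm.generate("Hello world!")[0].outputs[0].text)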