[Bugfix] Only add Attention.kv_scale if kv cache quantization is enabled (#5936)

Michael Goin 2024-06-28 17:12:40 -04:00 committed by GitHub
parent be0b3af9e0
commit 4bf35ed9ae


@@ -9,6 +9,7 @@ from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
 
 
 class Attention(nn.Module):
@@ -56,15 +57,19 @@ class Attention(nn.Module):
         quant_method = quant_config.get_quant_method(
             self) if quant_config else None
         if quant_method is not None:
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError("fp8_e5m2 kv-cache is not supported with "
-                                 "fp8 checkpoints.")
-            # When FP8 quantization is enabled, we make a parameter
-            # "kv_scale" so that it can be loaded from FP8 checkpoint.
-            # The kv_scale will then be converted back
-            # to self._kv_scale in a native float32 value after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
+            assert isinstance(quant_method, Fp8KVCacheMethod)
+            # TODO (mgoin): kv cache dtype should be specified in the FP8
+            # checkpoint config and become the "auto" behavior
+            if "fp8" in self.kv_cache_dtype:
+                if self.kv_cache_dtype == "fp8_e5m2":
+                    raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                     "fp8 checkpoints.")
+                # When FP8 quantization is enabled, we make a parameter
+                # "kv_scale" so that it can be loaded from FP8 checkpoint.
+                # The kv_scale will then be converted back to self._kv_scale
+                # in a native float32 value after weight loading.
+                self.quant_method = quant_method
+                self.quant_method.create_weights(self)
 
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
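
To illustrate the change in isolation: a minimal standalone sketch of the new gating, where kv_scale is only created when the kv cache dtype is an fp8 variant. The stub classes below are hypothetical stand-ins, not vLLM's actual Attention or Fp8KVCacheMethod.

# Sketch of the gating introduced by this commit, using stand-in classes.

class Fp8KVCacheMethodStub:
    """Stand-in for Fp8KVCacheMethod: creates the kv_scale parameter."""

    def create_weights(self, layer):
        # In vLLM this registers a loadable "kv_scale" parameter; here we
        # just attach a plain float attribute for illustration.
        layer.kv_scale = 1.0


class AttentionStub:
    """Stand-in for Attention showing only the kv_scale gating."""

    def __init__(self, kv_cache_dtype="auto", quant_method=None):
        self.kv_cache_dtype = kv_cache_dtype
        if quant_method is not None:
            # New behavior: only add kv_scale when the kv cache is actually
            # quantized to an fp8 dtype.
            if "fp8" in self.kv_cache_dtype:
                if self.kv_cache_dtype == "fp8_e5m2":
                    raise ValueError("fp8_e5m2 kv-cache is not supported "
                                     "with fp8 checkpoints.")
                self.quant_method = quant_method
                self.quant_method.create_weights(self)


# With kv cache quantization disabled ("auto"), no kv_scale is created.
layer = AttentionStub(kv_cache_dtype="auto",
                      quant_method=Fp8KVCacheMethodStub())
print(hasattr(layer, "kv_scale"))  # False

# With an fp8 kv cache dtype, kv_scale is created as before.
layer = AttentionStub(kv_cache_dtype="fp8_e4m3",
                      quant_method=Fp8KVCacheMethodStub())
print(hasattr(layer, "kv_scale"))  # True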