[Bugfix] Only add Attention.kv_scale if kv cache quantization is enabled (#5936)
This commit is contained in:
parent
be0b3af9e0
commit
4bf35ed9ae
@ -9,6 +9,7 @@ from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@ -56,15 +57,19 @@ class Attention(nn.Module):
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self) if quant_config else None
|
||||
if quant_method is not None:
|
||||
if self.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||
"fp8 checkpoints.")
|
||||
# When FP8 quantization is enabled, we make a parameter
|
||||
# "kv_scale" so that it can be loaded from FP8 checkpoint.
|
||||
# The kv_scale will then be converted back
|
||||
# to self._kv_scale in a native float32 value after weight loading.
|
||||
self.quant_method = quant_method
|
||||
self.quant_method.create_weights(self)
|
||||
assert isinstance(quant_method, Fp8KVCacheMethod)
|
||||
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||
# checkpoint config and become the "auto" behavior
|
||||
if "fp8" in self.kv_cache_dtype:
|
||||
if self.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||
"fp8 checkpoints.")
|
||||
# When FP8 quantization is enabled, we make a parameter
|
||||
# "kv_scale" so that it can be loaded from FP8 checkpoint.
|
||||
# The kv_scale will then be converted back to self._kv_scale
|
||||
# in a native float32 value after weight loading.
|
||||
self.quant_method = quant_method
|
||||
self.quant_method.create_weights(self)
|
||||
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user