[Bugfix] Only add Attention.kv_scale if kv cache quantization is enabled (#5936)
parent be0b3af9e0
commit 4bf35ed9ae
@@ -9,6 +9,7 @@ from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
 
 
 class Attention(nn.Module):
@@ -56,13 +57,17 @@ class Attention(nn.Module):
         quant_method = quant_config.get_quant_method(
             self) if quant_config else None
         if quant_method is not None:
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError("fp8_e5m2 kv-cache is not supported with "
-                                 "fp8 checkpoints.")
-            # When FP8 quantization is enabled, we make a parameter
-            # "kv_scale" so that it can be loaded from FP8 checkpoint.
-            # The kv_scale will then be converted back
-            # to self._kv_scale in a native float32 value after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
+            assert isinstance(quant_method, Fp8KVCacheMethod)
+            # TODO (mgoin): kv cache dtype should be specified in the FP8
+            # checkpoint config and become the "auto" behavior
+            if "fp8" in self.kv_cache_dtype:
+                if self.kv_cache_dtype == "fp8_e5m2":
+                    raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                     "fp8 checkpoints.")
+                # When FP8 quantization is enabled, we make a parameter
+                # "kv_scale" so that it can be loaded from FP8 checkpoint.
+                # The kv_scale will then be converted back to self._kv_scale
+                # in a native float32 value after weight loading.
+                self.quant_method = quant_method
+                self.quant_method.create_weights(self)
 
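For context, a minimal standalone sketch of the gating behavior this commit introduces, using illustrative stand-in classes (FakeFp8KVCacheMethod, FakeAttention are hypothetical, not vLLM's real API): kv_scale is only attached to the layer when the kv cache dtype is an fp8 variant, so non-quantized kv caches no longer get the parameter.

```python
# Illustrative sketch only -- these are NOT vLLM's actual classes.

class FakeFp8KVCacheMethod:
    """Stand-in for Fp8KVCacheMethod; the real one registers a Parameter."""

    def create_weights(self, layer):
        # Attach a placeholder kv_scale, mimicking checkpoint-loadable state.
        layer.kv_scale = 1.0


class FakeAttention:
    """Stand-in mirroring the patched Attention.__init__ control flow."""

    def __init__(self, kv_cache_dtype, quant_method=None):
        self.kv_cache_dtype = kv_cache_dtype
        if quant_method is not None:
            assert isinstance(quant_method, FakeFp8KVCacheMethod)
            # The new gate: only create kv_scale when kv cache quantization
            # is actually enabled.
            if "fp8" in self.kv_cache_dtype:
                if self.kv_cache_dtype == "fp8_e5m2":
                    raise ValueError("fp8_e5m2 kv-cache is not supported "
                                     "with fp8 checkpoints.")
                self.quant_method = quant_method
                self.quant_method.create_weights(self)


# kv cache quantization disabled: no kv_scale is created (the bugfix).
layer = FakeAttention("auto", quant_method=FakeFp8KVCacheMethod())
assert not hasattr(layer, "kv_scale")

# kv cache quantization enabled: kv_scale exists for checkpoint loading.
layer = FakeAttention("fp8", quant_method=FakeFp8KVCacheMethod())
assert hasattr(layer, "kv_scale")
```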