[Bugfix] Fix PerTensorScaleParameter weight loading for fused models (#7376)

parent 933790c209
commit 5c6c54d67a
@@ -14,7 +14,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (BasevLLMParameter,
-                                           PackedvLLMParameter)
+                                           PackedvLLMParameter,
+                                           PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
@@ -573,11 +574,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                          param: BasevLLMParameter,
                          loaded_weight: torch.Tensor,
                          loaded_shard_id: Optional[int] = None):
-        param_data = param.data
         if loaded_shard_id is None:
-            if param.output_dim is None:
-                assert param_data.shape == loaded_weight.shape
-                param_data.copy_(loaded_weight)
+            if isinstance(param, PerTensorScaleParameter):
+                param.load_merged_column_weight(loaded_weight=loaded_weight,
+                                                shard_id=0)
+                return
+            elif type(param) is BasevLLMParameter:
+                param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
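
The hunk above replaces an attribute probe (`param.output_dim is None`) with dispatch on the parameter's type: a PerTensorScaleParameter carries a single scale per logical shard rather than a sharded weight, so the old output_dim-based path could not load it when the checkpoint supplies the fused tensor with no shard id. Below is a minimal, hedged sketch of that dispatch pattern. The two classes are hypothetical stand-ins written for illustration, not the actual vllm.model_executor.parameter implementations; only the control flow in weight_loader_v2 mirrors the diff.

    # Minimal, self-contained sketch; stand-in classes are hypothetical.
    import torch

    class BasevLLMParameter:
        def __init__(self, data: torch.Tensor):
            self.data = data

        def load_merged_column_weight(self, loaded_weight: torch.Tensor,
                                      **kwargs):
            # Stand-in behavior: the checkpoint tensor replaces the parameter.
            self.data.copy_(loaded_weight)

    class PerTensorScaleParameter(BasevLLMParameter):
        # One scalar scale per logical shard of a fused layer; it has no
        # output_dim, which is why the old `param.output_dim` probe failed.
        def load_merged_column_weight(self, loaded_weight: torch.Tensor,
                                      shard_id: int = 0, **kwargs):
            # Route the single checkpoint scale into the requested shard slot.
            self.data[shard_id].copy_(loaded_weight)

    def weight_loader_v2(param, loaded_weight, loaded_shard_id=None):
        # Mirrors the fixed branch: dispatch on the parameter's type rather
        # than probing attributes that scale parameters do not define.
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) is BasevLLMParameter:
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return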
@@ -720,11 +723,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                          param: BasevLLMParameter,
                          loaded_weight: torch.Tensor,
                          loaded_shard_id: Optional[str] = None):
-        param_data = param.data
         if loaded_shard_id is None:  # special case for certain models
-            if param.output_dim is None:
-                assert param_data.shape == loaded_weight.shape
-                param_data.copy_(loaded_weight)
+            if isinstance(param, PerTensorScaleParameter):
+                param.load_merged_column_weight(loaded_weight=loaded_weight,
+                                                shard_id=0)
+                return
+            elif type(param) is BasevLLMParameter:
+                param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
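
The QKVParallelLinear hunk applies the same fix on its `loaded_shard_id is None` path, which certain checkpoints hit when a fused module is stored as a single tensor. A hedged usage sketch, reusing the hypothetical stand-ins from the previous sketch:

    # Reuses the stand-in classes above (not vLLM's own implementations).
    scale_param = PerTensorScaleParameter(data=torch.zeros(3))  # slot per shard
    fused_scale = torch.tensor(0.02)   # single per-tensor scale in checkpoint
    weight_loader_v2(scale_param, fused_scale)   # loaded_shard_id stays None
    print(scale_param.data)            # -> tensor([0.0200, 0.0000, 0.0000])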