[Misc] Refactor linear layer weight loading; introduce BasevLLMParameter and weight_loader_v2 (#5874)

This commit is contained in:
Dipika Sikka 2024-08-07 12:17:58 -04:00 committed by GitHub
parent 639159b2a6
commit 0f7052bc7e
11 changed files with 653 additions and 201 deletions
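For orientation, the core of this refactor is that quantization schemes now register typed parameter subclasses carrying their own sharding metadata and loading logic, instead of plain torch.nn.Parameter objects annotated via set_weight_attrs. A minimal sketch of the new-style registration, assuming vLLM at this commit is importable (the helper name and layer here are illustrative):

import torch
from vllm.model_executor.parameter import ModelWeightParameter

def create_weights_sketch(layer: torch.nn.Module, output_partition_sizes,
                          input_size_per_partition, params_dtype, weight_loader):
    # Sharding metadata (input_dim/output_dim) and the weight loader now live
    # on the parameter object itself rather than in set_weight_attrs().
    weight = ModelWeightParameter(data=torch.empty(sum(output_partition_sizes),
                                                   input_size_per_partition,
                                                   dtype=params_dtype),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight", weight)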

View File

@ -9,7 +9,7 @@ import torch
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)
@ -109,7 +109,7 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.weight_packed.pack_factor == pack_factor
assert qkv_proj.scheme.pack_factor == pack_factor
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@ -140,13 +140,17 @@ def test_compressed_tensors_fp8(vllm_runner):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert isinstance(
qkv_proj.scheme,
(CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
assert qkv_proj.input_scale.dtype is torch.float32
assert qkv_proj.weight_scale.dtype is torch.float32
# should be scalars after processing
assert len(qkv_proj.input_scale.shape) == 0
assert len(qkv_proj.weight_scale.shape) == 0
if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output

View File

@ -1,7 +1,11 @@
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
__all__ = [
"SamplingMetadata",
"set_random_seed",
"BasevLLMParameter",
"PackedvLLMParameter",
]
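With these exports the new parameter classes are reachable from the package root; a trivial check, assuming this commit is installed:

from vllm.model_executor import BasevLLMParameter, PackedvLLMParameter

# PackedvLLMParameter specializes BasevLLMParameter (via ModelWeightParameter),
# so the weight_loader_v2 paths can treat all parameter flavors uniformly.
print(issubclass(PackedvLLMParameter, BasevLLMParameter))  # True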

View File

@ -13,10 +13,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
@ -288,6 +292,7 @@ class ColumnParallelLinear(LinearBase):
if output_sizes is None:
output_sizes = [output_size]
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
@ -295,7 +300,9 @@ class ColumnParallelLinear(LinearBase):
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
if bias:
self.bias = Parameter(
@ -337,6 +344,9 @@ class ColumnParallelLinear(LinearBase):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
@ -527,6 +537,62 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where MLP layers are already
fused on disk. In this case, we have no shard id. This function
determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
current_shard_offset = 0
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
if loaded_shard_id is None:
if param.output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
@ -598,6 +664,82 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config=quant_config,
prefix=prefix)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
}
return shard_offset_mapping.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
"""
Handle special case for models where QKV layers are already
fused on disk. In this case, we have no shard id. This function
determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = param.data
if loaded_shard_id is None: # special case for certain models
if param.output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id in ["q", "k", "v"]
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
param.load_qkv_weight(loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
@ -798,6 +940,7 @@ class RowParallelLinear(LinearBase):
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
@ -805,7 +948,9 @@ class RowParallelLinear(LinearBase):
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
@ -850,6 +995,10 @@ class RowParallelLinear(LinearBase):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_):
if self.input_is_parallel:
input_parallel = input_
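To make the shard bookkeeping in the weight_loader_v2 paths concrete, this standalone sketch reproduces the offset/size arithmetic for QKVParallelLinear and MergedColumnParallelLinear with hypothetical per-rank sizes (32 query heads, 8 KV heads, head size 128, tensor-parallel size 2; not taken from any particular model):

# QKV: offsets and sizes of the q/k/v shards inside the fused output dimension.
num_heads, num_kv_heads, head_size, tp_size = 32, 8, 128, 2
qkv_offsets = {
    "q": 0,
    "k": num_heads * head_size,
    "v": (num_heads + num_kv_heads) * head_size,
    "total": (num_heads + 2 * num_kv_heads) * head_size,
}
qkv_sizes = {
    "q": num_heads * head_size,
    "k": num_kv_heads * head_size,
    "v": num_kv_heads * head_size,
}
assert qkv_offsets["total"] == sum(qkv_sizes.values())

# MergedColumn: per-rank shard offset/size for logical output `loaded_shard_id`.
output_sizes = [11008, 11008]   # hypothetical fused gate/up projection
loaded_shard_id = 1
shard_offset = sum(output_sizes[:loaded_shard_id]) // tp_size   # 5504
shard_size = output_sizes[loaded_shard_id] // tp_size           # 5504
print(qkv_offsets, qkv_sizes, shard_offset, shard_size)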

View File

@ -19,6 +19,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
__all__ = ["CompressedTensorsLinearMethod"]
class CompressedTensorsConfig(QuantizationConfig):
@ -146,18 +148,15 @@ class CompressedTensorsConfig(QuantizationConfig):
if weight_quant is None or input_quant is None:
return False
# Confirm we have floating points.
if not (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT):
return False
# Confirm weight scheme is supported.
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight):
return False
@ -169,11 +168,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
if not (is_symmetric_activation and is_per_tensor_activation):
return False
# All conditions satisfied.
return True
return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@ -230,6 +225,7 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size=weight_quant.group_size)
# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
@ -237,7 +233,8 @@ class CompressedTensorsConfig(QuantizationConfig):
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
else:
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
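The refactored _is_fp8_w8a8 check above reduces to two conjunctions (weight conditions, then activation conditions); restated standalone with hypothetical flag values standing in for the weight_quant/input_quant fields:

# Hypothetical stand-ins for the weight_quant / input_quant BaseModel fields.
weight_is_float = input_is_float = True
weight_symmetric, weight_dynamic = True, False
weight_per_tensor_or_channel = True
input_symmetric, input_per_tensor = True, True

weight_ok = (weight_is_float and input_is_float and weight_symmetric
             and not weight_dynamic and weight_per_tensor_or_channel)
activation_ok = input_symmetric and input_per_tensor
print(weight_ok and activation_ok)  # True -> the fp8 W8A8 path applies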

View File

@ -2,11 +2,10 @@ from typing import Callable, List, Optional
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import ModelWeightParameter
__all__ = ["CompressedTensorsUnquantized"]
@ -24,7 +23,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
return 70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
# required by torch.compile to be torch.nn.Parameter
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
@ -32,14 +33,15 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"weight_loader": weight_loader})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

View File

@ -8,7 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"]
@ -45,7 +48,12 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
return 80
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
# required by torch.compile to be torch.nn.Parameter
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
layer.scale_packed = Parameter(layer.scale_packed.data,
requires_grad=False)
layer.meta = Parameter(layer.meta.data, requires_grad=False)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
@ -56,79 +64,65 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter(
torch.empty(
input_size_per_partition // self.tile_size // 2,
output_size_per_partition * self.tile_size // pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": pack_factor,
"marlin_tile_size": self.tile_size,
"weight_loader": weight_loader
},
)
layer.register_parameter("weight_packed", qweight)
qweight = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // self.tile_size // 2,
output_size_per_partition * self.tile_size // pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=pack_factor,
marlin_tile_size=self.tile_size,
weight_loader=weight_loader)
input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size)
scales = Parameter(
weight_scale_args = {
"data":
torch.empty(
input_groups,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"output_dim": 1,
"input_dim": None if input_groups == 1 else 0,
"weight_loader": weight_loader
},
)
layer.register_parameter("scale_packed", scales)
"weight_loader":
weight_loader
}
weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
requires_grad=False)
if self.group_size is not None:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
else:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
meta = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", qweight)
layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
meta = Parameter(
torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
requires_grad=False,
)
set_weight_attrs(
meta,
{
"input_dim": 0,
"packed_dim": 1,
"pack_factor": 1,
"output_dim": 1,
"marlin_tile_size": 2,
"weight_loader": weight_loader
},
)
layer.register_parameter("scale_packed", scales)
layer.register_parameter("meta", meta)
max_workspace_size = (
output_size_per_partition //
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
requires_grad=False)
layer.workspace = workspace
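To see why the packed parameter carries both a packing factor and a marlin tile size, here is the shard-index arithmetic of adjust_shard_indexes_for_packing spelled out with hypothetical numbers (4-bit weights, so pack_factor = 32 // 4 = 8, and an assumed marlin tile of 16):

pack_factor = 32 // 4   # int4 values packed into int32 words
marlin_tile_size = 16   # assumed tile size for the marlin 2:4 kernel layout

def adjust_for_packing(shard_size, shard_offset):
    # First express the shard indexes in packed (int32) units ...
    shard_size //= pack_factor
    shard_offset //= pack_factor
    # ... then scale by the marlin tile, matching PackedvLLMParameter.
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size

print(adjust_for_packing(shard_size=4096, shard_offset=8192))  # (8192, 16384)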

View File

@ -9,9 +9,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, create_per_channel_scale_param,
create_per_tensor_scale_param)
from vllm.model_executor.utils import set_weight_attrs
convert_to_channelwise)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A16Fp8"]
@ -40,11 +41,19 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer.logical_widths)
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
requires_grad=False)
else:
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
# Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(),
requires_grad=False)
if self.is_static_input_scheme:
# required by torch.compile to be torch.nn.Parameter
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
prepare_fp8_layer_for_marlin(layer, strategy="channel")
def create_weights(self, layer: torch.nn.Module, input_size: int,
@ -60,35 +69,39 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer.orig_dtype = params_dtype
# WEIGHT
weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param(
output_partition_sizes, **layer_kwargs)
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
elif self.strategy == QuantizationStrategy.TENSOR:
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
else:
raise ValueError(
f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}")
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (to deal with converted checkpoints)
if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
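The two scale layouts registered above differ only in shape; hypothetical fused output partitions of [5632, 5632] make the difference concrete:

import torch

output_partition_sizes = [5632, 5632]   # hypothetical fused MLP partitions

# CHANNEL strategy: one fp32 scale per output row of the fused weight.
channel_scale = torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32)
# TENSOR strategy: one fp32 scale per logical matrix in the fusion.
tensor_scale = torch.empty(len(output_partition_sizes), dtype=torch.float32)

print(channel_scale.shape, tensor_scale.shape)  # torch.Size([11264, 1]) torch.Size([2])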

View File

@ -8,10 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param,
create_per_tensor_scale_param, cutlass_fp8_supported,
requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A8Fp8"]
@ -46,6 +46,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
@ -66,32 +69,40 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param(
output_partition_sizes, **layer_kwargs)
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
# min requirement for fp8 kernels
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,

View File

@ -8,9 +8,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_int8_linear, convert_to_channelwise, create_per_channel_scale_param,
create_per_tensor_scale_param)
from vllm.model_executor.utils import set_weight_attrs
apply_int8_linear, convert_to_channelwise)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
@ -39,7 +41,9 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
ws_channelwise = convert_to_channelwise(layer.weight_scale,
self.logical_widths)
layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
else:
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
@ -55,32 +59,35 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.logical_widths = output_partition_sizes
# WEIGHT
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param(
output_partition_sizes, **layer_kwargs)
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,

View File

@ -1,7 +1,6 @@
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@ -10,7 +9,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"]
@ -30,17 +32,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.group_size = -1 if group_size is None else group_size
self.group_size: int
if group_size is None:
if self.strategy != "channel":
raise ValueError(
"Marlin kernels require group quantization or "
"channelwise quantization, but found no group "
"size and strategy is not channelwise.")
self.group_size = -1
else:
self.group_size = group_size
if self.group_size == -1 and self.strategy != "channel":
raise ValueError("Marlin kernels require group quantization or "
"channelwise quantization, but found no group "
"size and strategy is not channelwise.")
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
raise ValueError(
@ -63,11 +60,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
# If group_size is -1, we are in channelwise case.
channelwise = (self.group_size == -1)
group_size = input_size if channelwise else self.group_size
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
@ -79,60 +77,51 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
input_size=input_size,
group_size=group_size)
weight_scale_dim = None
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
weight_scale_dim = 1
scales_and_zp_size = input_size_per_partition // group_size
weight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition // self.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
weight = PackedvLLMParameter(input_dim=1,
output_dim=0,
weight_loader=weight_loader,
packed_factor=self.pack_factor,
packed_dim=1,
data=torch.empty(
output_size_per_partition,
input_size_per_partition //
self.pack_factor,
dtype=torch.int32,
))
set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
"packed_dim": 1,
"pack_factor": self.pack_factor,
"weight_loader": weight_loader
})
layer.register_parameter("weight_packed", weight)
weight_scale = Parameter(
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"input_dim": weight_scale_dim,
"output_dim": 0
})
layer.register_parameter("weight_scale", weight_scale)
)
}
if self.group_size == -1:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
# A tensor with two int64 entries recording the original shape of the
# weights before packing
weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
requires_grad=False)
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
@ -154,10 +143,15 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
# Update for kernel
layer.weight_packed = torch.nn.Parameter(
layer.weight_packed.t().contiguous(), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed.t().contiguous(),
layer.weight_packed,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
@ -166,7 +160,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
# Permute scales from compressed-tensors format to marlin format.
marlin_scales = marlin_permute_scales(
layer.weight_scale.squeeze().t().contiguous(),
layer.weight_scale,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=layer.group_size)
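A worked example of the scale-shape bookkeeping above, with hypothetical sizes (input_size = 4096, group_size = 128, tensor-parallel size 2):

input_size, group_size, tp_size = 4096, 128, 2   # hypothetical sizes
input_size_per_partition = input_size // tp_size

# Scales replicated across ranks (e.g. repeated group scales):
scales_replicated = input_size // group_size                 # 32 groups
# Scales partitioned per rank (group quantization with row parallelism):
scales_partitioned = input_size_per_partition // group_size  # 16 groups
# Channelwise (group_size == -1) uses group_size = input_size, so the
# scales tensor has a single column (one scale per output channel).

print(scales_replicated, scales_partitioned)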

View File

@ -0,0 +1,277 @@
from typing import Callable, Optional, Union
import torch
from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
"ModelWeightParameter", "ChannelQuantScaleParameter",
"GroupQuantScaleParameter"
]
logger = init_logger(__name__)
class BasevLLMParameter(Parameter):
"""
Base parameter for vLLM linear layers. Extends torch.nn.Parameter
by taking in a linear weight loader. Will copy the loaded weight
into the parameter when the provided weight loader is called.
"""
def __new__(cls, data: torch.Tensor, **kwargs):
return super().__new__(cls, data=data, requires_grad=False)
def __init__(self, data: torch.Tensor, weight_loader: Callable):
"""
Initialize the BasevLLMParameter
:param data: torch tensor with the parameter data
:param weight_loader: weight loader callable
:returns: a torch.nn.Parameter
"""
self._weight_loader = weight_loader
@property
def weight_loader(self):
return self._weight_loader
def _assert_and_load(self, loaded_weight: torch.Tensor):
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
self._assert_and_load(loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
self._assert_and_load(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
self._assert_and_load(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
self._assert_and_load(loaded_weight)
class _ColumnvLLMParameter(BasevLLMParameter):
"""
Private class defining weight loading functionality
(load_merged_column_weight, load_qkv_weight)
for parameters being loaded into linear layers with column
parallelism. This includes QKV and MLP layers which are
not already fused on disk. Requires an output dimension
to be defined. Called within the weight loader of
each of the column parallel linear layers.
"""
def __init__(self, output_dim: int, **kwargs):
self._output_dim = output_dim
super().__init__(**kwargs)
@property
def output_dim(self):
return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
if isinstance(
self,
PackedvLLMParameter) and self.packed_dim == self.output_dim:
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size)
param_data = self.data
tp_rank = get_tensor_model_parallel_rank()
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
shard_id = kwargs.get("shard_id")
num_heads = kwargs.get("num_heads")
if isinstance(
self,
PackedvLLMParameter) and self.output_dim == self.packed_dim:
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size)
param_data = self.data
tp_rank = get_tensor_model_parallel_rank()
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
shard_id * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class ModelWeightParameter(_ColumnvLLMParameter):
"""
Parameter class for linear layer weights. Extends
_ColumnvLLMParameter by adding loading support for
linear layers with row parallelism.
Requires an input dimension to be defined.
"""
def __init__(self, input_dim: int, **kwargs):
self._input_dim = input_dim
super().__init__(**kwargs)
@property
def input_dim(self):
return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
class GroupQuantScaleParameter(ModelWeightParameter):
"""
Parameter class for weight scales loaded for weights with
grouped quantization. Equivalent to ModelWeightParameter.
"""
pass
class ChannelQuantScaleParameter(_ColumnvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
channel-wise quantization. Equivalent to _ColumnvLLMParameter.
"""
pass
class PerTensorScaleParameter(BasevLLMParameter):
"""
Parameter class for scales where the number of scales is
equivalent to the number of logical matrices in fused linear
layers (e.g. for QKV, there are 3 scales loaded from disk).
This is relevant to weights with per-tensor quantization.
Adds functionality to map the scales to a shard during
weight loading.
Note: additional, quantization-config-specific parameter
manipulation may be handled within
process_weights_after_loading.
"""
def __init__(self, **kwargs):
self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
super().__init__(**kwargs)
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
assert isinstance(shard_id, str)
assert shard_id in self.qkv_idxs
return self.qkv_idxs[shard_id]
def load_merged_column_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs)
def load_qkv_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs)
def load_column_parallel_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs)
def _load_into_shard_id(self, loaded_weight: torch.Tensor,
shard_id: Union[str, int], **kwargs):
"""
Slice the parameter data based on the shard id for
loading.
"""
param_data = self.data
shard_id = self._shard_id_as_int(shard_id)
# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if len(loaded_weight.shape) != 0:
assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0]
param_data = param_data[shard_id]
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class PackedvLLMParameter(ModelWeightParameter):
"""
Parameter for model weights which are packed on disk.
Example: GPTQ Marlin weights are int4 or int8, packed into int32.
Extends the ModelWeightParameter to take in the
packed factor, the packed dimension, and optionally, marlin
tile size for marlin kernels. Adjusts the shard_size and
shard_offset for fused linear layer model weight loading
by accounting for packing and, optionally, marlin tile size.
"""
def __init__(self,
packed_factor: int,
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
self._marlin_tile = marlin_tile_size
super().__init__(**kwargs)
@property
def packed_dim(self):
return self._packed_dim
@property
def packed_factor(self):
return self._packed_factor
@property
def marlin_tile(self):
return self._marlin_tile
def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset):
return shard_size * self.marlin_tile, shard_offset * self.marlin_tile
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
shard_size = shard_size // self.packed_factor
shard_offset = shard_offset // self.packed_factor
if self.marlin_tile is not None:
return self._adjust_shard_indexes_for_marlin(
shard_size, shard_offset)
return shard_size, shard_offset
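A hedged usage sketch of the new classes (assumes vLLM at this commit is importable; the shapes, the no-op loader, and the 4-bit packing factor are illustrative):

import torch
from vllm.model_executor.parameter import (PackedvLLMParameter,
                                            PerTensorScaleParameter)

def noop_loader(param, loaded_weight):
    # Placeholder weight loader; real loaders end up calling param.load_*_weight().
    pass

# Packed 4-bit weights: 8 logical values per int32 word (packed_factor = 32 // 4).
packed = PackedvLLMParameter(data=torch.empty(64, 32, dtype=torch.int32),
                             input_dim=0, output_dim=1,
                             packed_dim=1, packed_factor=8,
                             weight_loader=noop_loader)
# Shard indexes expressed in unpacked units are divided by the packing factor.
print(packed.adjust_shard_indexes_for_packing(shard_size=64, shard_offset=128))  # (8, 16)

# One scale per logical matrix of a fused layer: q, k, v map to indices 0, 1, 2.
scales = PerTensorScaleParameter(data=torch.zeros(3), weight_loader=noop_loader)
scales.load_qkv_weight(loaded_weight=torch.tensor(0.5), shard_id="k")
print(scales)  # only the "k" slot (index 1) has been written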