[Model] Correct Mixtral FP8 checkpoint loading (#5231)
This commit is contained in:
parent
ccd4f129e8
commit
5563a4dea8
@ -300,14 +300,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
|
||||
|
||||
|
||||
def per_tensor_quantize(tensor: torch.Tensor,
|
||||
inv_scale: float) -> torch.Tensor:
|
||||
inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return qweight.to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def per_tensor_dequantize(tensor: torch.Tensor,
|
||||
inv_scale: float) -> torch.Tensor:
|
||||
def per_tensor_dequantize(
|
||||
tensor: torch.Tensor, inv_scale: Union[float,
|
||||
torch.Tensor]) -> torch.Tensor:
|
||||
fake_qweight = tensor.to(torch.float16)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
@ -41,7 +41,9 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
||||
per_tensor_dequantize,
|
||||
per_tensor_quantize)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -98,16 +100,16 @@ class MixtralMoE(nn.Module):
|
||||
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
self.w13_weight = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
dtype=params_dtype))
|
||||
self.w2_weight = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
dtype=params_dtype))
|
||||
self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
|
||||
set_weight_attrs(self.w13_weight, {
|
||||
"weight_loader": self.weight_loader,
|
||||
@ -124,7 +126,10 @@ class MixtralMoE(nn.Module):
|
||||
|
||||
if self.use_fp8:
|
||||
# WEIGHT_SCALE (for fp8)
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
2,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
@ -148,11 +153,11 @@ class MixtralMoE(nn.Module):
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8.")
|
||||
self.a13_scale = nn.Parameter(torch.zeros(
|
||||
self.a13_scale = nn.Parameter(torch.ones(
|
||||
self.num_total_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(torch.zeros(
|
||||
self.num_total_experts, dtype=torch.float32),
|
||||
self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
set_weight_attrs(self.a13_scale, {
|
||||
@ -175,8 +180,22 @@ class MixtralMoE(nn.Module):
|
||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w2.weight"):
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
if "act_scale" in weight_name or "weight_scale" in weight_name:
|
||||
|
||||
# Loading scales
|
||||
if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
|
||||
if param_data[expert_id] != 1 and (param_data[expert_id] -
|
||||
loaded_weight).abs() > 1e-5:
|
||||
raise ValueError(
|
||||
"act_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param_data[expert_id]} "
|
||||
f"vs. {loaded_weight}")
|
||||
param_data[expert_id] = loaded_weight
|
||||
elif "weight_scale" in weight_name:
|
||||
# We have to keep the weight scales of w1 and w3 because
|
||||
# we need to re-quantize w1/w3 weights after weight loading.
|
||||
assert "w1" in weight_name or "w3" in weight_name
|
||||
shard_id = 0 if "w1" in weight_name else 1
|
||||
param_data[expert_id][shard_id] = loaded_weight
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
# Fp8 is the only case where we need to process after loading.
|
||||
@ -189,6 +208,12 @@ class MixtralMoE(nn.Module):
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_weight = torch.empty_like(self.w2_weight.data,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
# Re-initialize w13_scale because we directly quantize
|
||||
# merged w13 weights and generate a single scaling factor.
|
||||
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
for expert in range(self.num_total_experts):
|
||||
w13_weight[expert, :, :], self.w13_scale[
|
||||
expert] = ops.scaled_fp8_quant(
|
||||
@ -199,25 +224,44 @@ class MixtralMoE(nn.Module):
|
||||
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
|
||||
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
|
||||
|
||||
# If checkpoint is fp8 + static, cleanup act_scales.
|
||||
# Since state_dict has an act_scale per expert but our kernels
|
||||
# are passed one act_scale shared across all experts.
|
||||
elif self.quant_config.activation_scheme == "static":
|
||||
if self.a13_scale is None or self.a2_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
else:
|
||||
# If checkpoint is fp8 + static, cleanup act_scales.
|
||||
# Since state_dict has an act_scale per expert but our kernels
|
||||
# are passed one act_scale shared across all experts.
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if self.a13_scale is None or self.a2_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
|
||||
if (not all_close_1d(self.a13_scale)
|
||||
or not all_close_1d(self.a2_scale)):
|
||||
print_warning_once(
|
||||
"Found act_scales that are not equal for fp8 MoE layer. "
|
||||
"Using the maximum across experts for each layer. ")
|
||||
if (not all_close_1d(self.a13_scale)
|
||||
or not all_close_1d(self.a2_scale)):
|
||||
print_warning_once(
|
||||
"Found act_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer. ")
|
||||
|
||||
self.a13_scale = nn.Parameter(self.a13_scale.max(),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(self.a2_scale.max(),
|
||||
requires_grad=False)
|
||||
self.a13_scale = nn.Parameter(self.a13_scale.max(),
|
||||
requires_grad=False)
|
||||
self.a2_scale = nn.Parameter(self.a2_scale.max(),
|
||||
requires_grad=False)
|
||||
|
||||
assert self.w13_scale is not None
|
||||
shard_size = self.intermediate_size
|
||||
max_w13_scales = self.w13_scale.max(dim=1).values
|
||||
for expert_id in range(self.num_total_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
self.w13_weight[expert_id][start:start +
|
||||
shard_size, :],
|
||||
self.w13_scale[expert_id][shard_id])
|
||||
self.w13_weight[expert_id][
|
||||
start:start + shard_size, :] = per_tensor_quantize(
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
Loading…
Reference in New Issue
Block a user