[ Misc ] Apply MoE Refactor to Deepseekv2 To Support Fp8 (#6417)
parent eeceadaecc · commit fb6af8bc08
@@ -0,0 +1,11 @@
+# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2
+model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.671
+  - name: "exact_match,flexible-extract"
+    value: 0.664
+limit: 1000
+num_fewshot: 5
@@ -1,3 +1,4 @@
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml
+DeepSeek-V2-Lite-Chat.yaml
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done
 
 lm_eval --model vllm \
-  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray" \
+  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true \
   --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
   --batch_size $BATCH_SIZE
@@ -394,14 +394,16 @@ def fused_topk(
 
 
 # This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
+def grouped_topk(hidden_states: torch.Tensor,
+                 gating_output: torch.Tensor,
+                 topk: int,
+                 renormalize: bool,
+                 num_expert_group: int = 0,
+                 topk_group: int = 0):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+
     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
     group_scores = scores.view(num_token, num_expert_group,
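Note (reviewer sketch, not part of the diff): grouped_topk restricts routing to a subset of expert groups, which is how DeepSeek-V2 selects experts. The pure-PyTorch sketch below is only meant to make num_expert_group and topk_group concrete; the name grouped_topk_sketch and the exact masking are illustrative and assume n_experts is divisible by num_expert_group.

import torch

def grouped_topk_sketch(gating_output: torch.Tensor, topk: int, renormalize: bool,
                        num_expert_group: int, topk_group: int):
    # Score every expert, then score each group by its best expert.
    scores = torch.softmax(gating_output, dim=-1)          # [num_token, n_experts]
    num_token, n_experts = scores.shape
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    # Keep only the topk_group best groups per token.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        n_experts // num_expert_group).reshape(num_token, -1)
    # Pick the final top-k experts from the surviving groups only.
    topk_weights, topk_ids = torch.topk(scores.masked_fill(score_mask == 0, 0.0),
                                        k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids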
@@ -557,6 +559,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -579,6 +584,10 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - num_expert_group: Optional[int]: additional parameter for grouped_topk
+    - topk_group: Optional[int]: additional parameter for grouped_topk
+    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
+        note: Deepseekv2 model uses grouped_topk
     - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -592,8 +601,15 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
+                                              topk, renormalize,
+                                              num_expert_group, topk_group)
+    else:
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
+
     return fused_experts(hidden_states,
                          w1,
                          w2,
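Note (illustrative usage, not from the diff): with the new keyword arguments a caller opts into DeepSeek-V2-style routing as sketched below. The shapes, dtype, and k values are invented for the example; use_grouped_topk=True requires num_expert_group and topk_group to be set, and the fused kernel still needs a CUDA device.

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

num_tokens, hidden_size, intermediate_size, num_experts = 4, 128, 256, 8
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
# w1 holds the merged gate/up projections, w2 the down projection, per expert.
w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size, device="cuda", dtype=torch.float16)
w2 = torch.randn(num_experts, hidden_size, intermediate_size, device="cuda", dtype=torch.float16)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float16)

out = fused_moe(hidden_states, w1, w2, gating_output,
                topk=2,
                renormalize=True,
                use_grouped_topk=True,   # dispatch to grouped_topk instead of fused_topk
                num_expert_group=4,
                topk_group=2)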
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -29,7 +29,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
         raise NotImplementedError
 
 
@@ -63,7 +66,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
 
         return fused_moe(x,
                          layer.w13_weight,
@@ -71,7 +77,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
                          router_logits,
                          top_k,
                          renormalize=renormalize,
-                         inplace=True)
+                         inplace=True,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
 
 
 class FusedMoE(torch.nn.Module):
@@ -104,6 +113,9 @@ class FusedMoE(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
     ):
@@ -119,6 +131,11 @@ class FusedMoE(torch.nn.Module):
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -140,9 +157,8 @@ class FusedMoE(torch.nn.Module):
                       shard_id: int, expert_id: int):
         param_data = param.data
 
-        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
@@ -150,14 +166,21 @@ class FusedMoE(torch.nn.Module):
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
-        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
+        # Weight scales
         elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
+            # If we are in merged column case (gate_up_proj)
+            # shard_id 0 == gate_proj / w1
+            # shard_id 2 == up_proj / w3
+            if shard_id == 0 or shard_id == 2:
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == 0 else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            # shard_id 1 == down_proj / w2
+            else:
+                param_data[expert_id] = loaded_weight
+        # Weights
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
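Note: shard_id here follows the checkpoint projection order produced by make_expert_params_mapping below, i.e. 0 = gate_proj/w1, 1 = down_proj/w2, 2 = up_proj/w3. That is why the merged column case keeps two weight scales per expert (index 0 for the gate projection, index 1 for the up projection), while the row-parallel down projection keeps a single scale per expert.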
@@ -188,10 +211,50 @@ class FusedMoE(torch.nn.Module):
             x=hidden_states,
             router_logits=router_logits,
             top_k=self.top_k,
-            renormalize=self.renormalize)
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            num_expert_group=self.num_expert_group,
+            topk_group=self.topk_group)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
         return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
+            ckpt_up_proj_name: str,
+            num_experts: int) -> List[Tuple[str, str, int, int]]:
+
+        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
+        gate_down_up = [
+            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
+        ]
+
+        return [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in gate_up else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight"
+             if weight_name in gate_up else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.a13_scale"
+             if weight_name in gate_up else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ]
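To make the mapping concrete, a hypothetical two-expert call with the DeepSeek-V2 checkpoint names used later in this PR would produce tuples like the ones sketched below (the example values are written out by hand from the comprehension above, not generated output).

from vllm.model_executor.layers.fused_moe import FusedMoE

mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=2)

# Each entry is (param_name, weight_name, expert_id, shard_id), for example:
#   ("experts.w13_weight", "experts.0.gate_proj.weight", 0, 0)
#   ("experts.w2_weight",  "experts.0.down_proj.weight", 0, 1)
#   ("experts.w13_weight", "experts.0.up_proj.weight",   0, 2)
# plus the matching "*.weight_scale" and "*.input_scale" entries for fp8 checkpoints.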
@@ -377,7 +377,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
 
         return fused_moe(x,
                          layer.w13_weight,
@@ -390,7 +393,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                          w1_scale=layer.w13_scale,
                          w2_scale=layer.w2_scale,
                          a1_scale=layer.a13_scale,
-                         a2_scale=layer.a2_scale)
+                         a2_scale=layer.a2_scale,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
 
 
 class Fp8KVCacheMethod(QuantizeMethodBase):
@@ -29,11 +29,10 @@ from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
+from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_experts, grouped_topk
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -91,32 +90,34 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.n_routed_experts
-        self.top_k = config.num_experts_per_tok
         self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > self.n_routed_experts:
+        self.n_shared_experts = config.n_shared_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
+                f"the number of experts {config.n_routed_experts}.")
 
-        self.experts = nn.ModuleList([
-            DeepseekV2MLP(hidden_size=config.hidden_size,
-                          intermediate_size=config.moe_intermediate_size,
-                          hidden_act=config.hidden_act,
-                          quant_config=quant_config,
-                          reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+        if config.hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+                             "Only silu is supported for now.")
+        self.experts = FusedMoE(num_experts=config.n_routed_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config,
+                                use_grouped_topk=True,
+                                num_expert_group=config.n_group,
+                                topk_group=config.topk_group)
 
         self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
+                                     config.n_routed_experts,
                                      bias=False,
                                      quant_config=None)
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -128,50 +129,21 @@ class DeepseekV2MoE(nn.Module):
             reduce_results=False,
         )
 
-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.config.n_shared_experts is not None:
+        if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            renormalize=self.config.norm_topk_prob,
-            num_expert_group=self.config.n_group,
-            topk_group=self.config.topk_group)
-        final_hidden_states = fused_experts(
-            hidden_states,
-            self.w1,
-            self.w2,
-            topk_weights,
-            topk_ids,
-            inplace=True) * self.routed_scaling_factor
-        if self.config.n_shared_experts is not None:
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor
+        if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
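In short, the rewritten forward computes the routed output as self.experts(hidden_states, router_logits) scaled by routed_scaling_factor, adds the shared-expert output when n_shared_experts is set, and only performs the tensor-parallel all-reduce when tp_size > 1, whereas the old path always all-reduced.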
@@ -504,34 +476,58 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
@@ -372,31 +372,13 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping = [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_scale"
-             if weight_name in ["w1", "w3"] else "experts.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
-             shard_id) for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ] + [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id)
-            ("experts.w13_weight"
-             if weight_name in ["w1", "w3"] else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ] + [
-            # These are the activation scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("experts.a13_scale"
-             if weight_name in ["w1", "w3"] else "experts.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
-             shard_id) for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -50,6 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
+from vllm.utils import print_warning_once
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -406,15 +407,13 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
-             else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_experts) for shard_id,
-            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -461,8 +460,20 @@ class Qwen2MoeForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name not in params_dict:
-                    continue
+                # Remapping the name of FP8 kv-scale.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            "Found kv scale in the checkpoint "
+                            f"(e.g. {name}), but not found the expected "
+                            f"name in the model "
+                            f"(e.g. {remapped_kv_scale_name}). "
+                            "kv-scale is not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
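A concrete (hypothetical) example of the kv-scale remapping above: a checkpoint key such as the one below is redirected to the attention module's parameter name before lookup; if the remapped name is absent, the scale is skipped with a warning.

ckpt_name = "model.layers.0.self_attn.kv_scale"        # illustrative checkpoint key
param_name = ckpt_name.replace(".kv_scale", ".attn.kv_scale")
# -> "model.layers.0.self_attn.attn.kv_scale"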