[model][utils] add extract_layer_index utility function (#10599)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, 2024-11-23 22:22:54 -08:00 (committed by GitHub)
Parent: eda2b3589c
Commit: c055747867
6 changed files with 59 additions and 51 deletions
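
In short: instead of threading an explicit layer_idx / layer_id argument through every constructor, each module now recovers its layer index from the module prefix that make_layers already passes down (e.g. "model.layers.3"). A minimal sketch of the idea, with the helper body copied from the utils.py hunk at the end of this diff; the surrounding example names are illustrative only:

from typing import List


def extract_layer_index(layer_name: str) -> int:
    # Copied from the new utility added to utils.py at the end of this diff.
    subnames = layer_name.split(".")
    int_vals: List[int] = []
    for subname in subnames:
        try:
            int_vals.append(int(subname))
        except ValueError:
            continue
    assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                " only contain one integer")
    return int_vals[0]


# Before: call sites parsed the make_layers prefix positionally.
assert int("model.layers.3".split(".")[-1]) == 3
# After: any submodule recovers the index from its own (possibly deeper) prefix.
assert extract_layer_index("model.layers.3.self_attn") == 3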

vllm/model_executor/models/arctic.py

@@ -33,7 +33,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.arctic import ArcticConfig
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -44,15 +44,14 @@ class ArcticMLP(nn.Module):
     def __init__(self,
                  config: ArcticConfig,
-                 layer_id: int,
                  expert_id: int = -1,
                  is_residual_mlp: bool = False,
                  quant_config: Optional[QuantizationConfig] = None,
-                 reduce_results: bool = True):
+                 reduce_results: bool = True,
+                 prefix: str = ""):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.expert_id = expert_id
-        self.layer_id = layer_id
         self.ffn_dim = config.intermediate_size if not is_residual_mlp \
             else self.hidden_size
@@ -85,13 +84,14 @@ class ArcticMoE(nn.Module):
     def __init__(self,
                  config: ArcticConfig,
-                 layer_id: int,
                  tp_size: Optional[int] = None,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 reduce_results: bool = True):
+                 reduce_results: bool = True,
+                 prefix: str = ""):
         super().__init__()
+        layer_id = extract_layer_index(prefix)
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.hidden_size = config.hidden_size
         self.num_experts = config.num_local_experts
@@ -109,15 +109,16 @@ class ArcticMoE(nn.Module):
         if not self.is_moe_layer:
             self.mlp = ArcticMLP(config,
-                                 layer_id=layer_id,
                                  quant_config=quant_config,
-                                 reduce_results=reduce_results)
+                                 reduce_results=reduce_results,
+                                 prefix=f"{prefix}.mlp")
         else:
             self.gate = ReplicatedLinear(self.hidden_size,
                                          self.num_experts,
                                          bias=False,
                                          params_dtype=self.params_dtype,
-                                         quant_config=quant_config)
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.gate")
             if self.is_quant:
                 self.ws = DeepSpeedFPParameter(
                     torch.Size((self.num_experts, 2 * self.intermediate_size,
@@ -220,14 +221,12 @@ class ArcticAttention(nn.Module):
     def __init__(
         self,
         config: ArcticConfig,
-        layer_idx: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.config = config
-        self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -298,26 +297,25 @@ class ArcticDecoderLayer(nn.Module):
     def __init__(
         self,
         config: ArcticConfig,
-        layer_idx: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
+        layer_idx = extract_layer_index(prefix)
         is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
         self.use_residual = config.use_residual and is_moe_layer
         self.self_attn = ArcticAttention(config,
-                                         layer_idx,
                                          cache_config,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.self_attn")
         self.block_sparse_moe = ArcticMoE(
             config,
-            layer_id=layer_idx,
             quant_config=quant_config,
-            reduce_results=(not self.use_residual))
+            reduce_results=(not self.use_residual),
+            prefix=f"{prefix}.block_sparse_moe",
+        )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -328,9 +326,9 @@ class ArcticDecoderLayer(nn.Module):
             self.residual_layernorm = RMSNorm(config.hidden_size,
                                               eps=config.rms_norm_eps)
             self.residual_mlp = ArcticMLP(config,
-                                          layer_id=layer_idx,
                                           is_residual_mlp=True,
-                                          reduce_results=False)
+                                          reduce_results=False,
+                                          prefix=f"{prefix}.residual_mlp")
     def forward(
         self,
@@ -384,11 +382,8 @@ class ArcticModel(nn.Module):
             org_num_embeddings=self.vocab_size)
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: ArcticDecoderLayer(config,
-                                              int(prefix.split(".")[-1]),
-                                              cache_config,
-                                              quant_config,
-                                              prefix=prefix),
+            lambda prefix: ArcticDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self._attn_implementation = config._attn_implementation
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

vllm/model_executor/models/deepseek.py

@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -63,6 +63,7 @@ class DeepseekMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -92,6 +93,7 @@ class DeepseekMoE(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -260,12 +262,12 @@ class DeepseekDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        layer_idx: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        layer_idx = extract_layer_index(prefix)
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -285,13 +287,16 @@ class DeepseekDecoderLayer(nn.Module):
         if (config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
                 and layer_idx % config.moe_layer_freq == 0):
-            self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
+            self.mlp = DeepseekMoE(config=config,
+                                   quant_config=quant_config,
+                                   prefix=f"{prefix}.mlp")
         else:
             self.mlp = DeepseekMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
             )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -347,11 +352,9 @@ class DeepseekModel(nn.Module):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: DeepseekDecoderLayer(config,
-                                                int(prefix.split(".")[-1]),
-                                                cache_config,
-                                                quant_config=quant_config,
-                                                prefix=prefix),
+            lambda prefix: DeepseekDecoderLayer(
+                config, cache_config, quant_config=quant_config, prefix=prefix
+            ),
             prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = (

vllm/model_executor/models/gemma2.py

@@ -42,7 +42,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -85,7 +86,6 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
     def __init__(self,
-                 layer_idx: int,
                  config: Gemma2Config,
                  hidden_size: int,
                  num_heads: int,
@@ -98,7 +98,6 @@ class Gemma2Attention(nn.Module):
                  attn_logits_soft_cap: Optional[float] = None,
                  prefix: str = "") -> None:
         super().__init__()
-        self.layer_idx = layer_idx
         self.config = config
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -145,6 +144,7 @@ class Gemma2Attention(nn.Module):
         # reference:
         # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
+        layer_idx = extract_layer_index(prefix)
         use_sliding_window = (layer_idx % 2 == 0 and
                               config.interleaved_sliding_window is not None)
         sliding_window = config.interleaved_sliding_window if \
@@ -178,7 +178,6 @@ class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
         config: Gemma2Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -187,7 +186,6 @@ class Gemma2DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
-            layer_idx=layer_idx,
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -262,11 +260,8 @@ class Gemma2Model(nn.Module):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
-                                              config,
-                                              cache_config,
-                                              quant_config,
-                                              prefix=prefix),
+            lambda prefix: Gemma2DecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

vllm/model_executor/models/olmoe.py

@@ -181,7 +181,6 @@ class OlmoeDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        layer_idx: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -264,11 +263,8 @@ class OlmoeModel(nn.Module):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: OlmoeDecoderLayer(config,
-                                             int(prefix.split(".")[-1]),
-                                             cache_config,
-                                             quant_config,
-                                             prefix=prefix),
+            lambda prefix: OlmoeDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)

vllm/model_executor/models/qwen2_moe.py

@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import print_warning_once
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -244,7 +244,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        layer_idx: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -269,6 +268,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
         # `mlp_only_layers` in the config.
+        layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
         if (layer_idx not in mlp_only_layers) and (
@@ -337,8 +337,6 @@ class Qwen2MoeModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Qwen2MoeDecoderLayer(config=config,
-                                                layer_idx=int(
-                                                    prefix.split(".")[-1]),
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix),

vllm/model_executor/models/utils.py

@@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str:
         The string "prefix.name" if prefix was non-empty, otherwise just "name".
     """
     return name if not prefix else f"{prefix}.{name}"
+
+
+def extract_layer_index(layer_name: str) -> int:
+    """
+    Extract the layer index from the module name.
+    Examples:
+        - "encoder.layers.0" -> 0
+        - "encoder.layers.1.self_attn" -> 1
+        - "2.self_attn" -> 2
+        - "model.encoder.layers.0.sub.1" -> ValueError
+    """
+    subnames = layer_name.split(".")
+    int_vals: List[int] = []
+    for subname in subnames:
+        try:
+            int_vals.append(int(subname))
+        except ValueError:
+            continue
+    assert len(int_vals) == 1, (f"layer name {layer_name} should"
+                                " only contain one integer")
+    return int_vals[0]
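
A short usage sketch of the new helper (illustrative only; the import path assumes the utils module shown above):

from vllm.model_executor.models.utils import extract_layer_index

# Prefixes of the form that make_layers produces ("<parent>.layers.<i>...").
assert extract_layer_index("model.layers.0") == 0
assert extract_layer_index("model.layers.7.mlp.gate") == 7

# A name containing two integers fails the single-integer assertion
# (the docstring's "model.encoder.layers.0.sub.1" example).
try:
    extract_layer_index("model.encoder.layers.0.sub.1")
except AssertionError:
    pass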