[core][distributed] simplify code to support pipeline parallel (#6406)
parent 44874a0bf9
commit 69672f116c
@@ -46,9 +46,7 @@ steps:
   fast_check: true
   commands:
   - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py
+  - pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
@@ -28,10 +28,8 @@ def test_vllm_gc_ed():
     assert weak_llm() is None
 
 
-@pytest.mark.skipif(is_hip()
-                    and os.getenv("VLLM_ATTENTION_BACKEND") == "FLASHINFER",
-                    reason="Flashinfer does not support ROCm/HIP.")
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
@@ -40,10 +38,17 @@ def test_models(
     vllm_runner,
     example_prompts,
     model: str,
+    backend: str,
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
+
+    if backend == "FLASHINFER" and is_hip():
+        pytest.skip("Flashinfer does not support ROCm/HIP.")
+
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
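Side note: the reason the three per-backend CI commands in the first hunk collapse into a single `pytest` invocation is that the backend is now a test parameter and is exported through `VLLM_ATTENTION_BACKEND` inside the test itself. A minimal, illustrative sketch of that pattern (the test name and assertion here are placeholders, not part of this commit):

```python
# Illustrative sketch only: mirrors the env-var-based backend selection in the
# updated test; a real run would also build hf_runner/vllm_runner and compare
# greedy outputs.
import os

import pytest

BACKENDS = ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]


@pytest.mark.parametrize("backend", BACKENDS)
def test_backend_matrix_sketch(backend: str, monkeypatch) -> None:
    # The variable must be set before any vLLM engine is constructed, which is
    # why the real test assigns os.environ["VLLM_ATTENTION_BACKEND"] up front.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    assert os.environ["VLLM_ATTENTION_BACKEND"] == backend
```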
@@ -27,7 +27,6 @@ from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_world_size)
-from vllm.distributed.utils import get_pp_indices
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -42,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
+from .utils import is_pp_missing_parameter, make_layers
+
 
 class GPT2Attention(nn.Module):
 
@@ -183,18 +184,9 @@ class GPT2Model(nn.Module):
         self.embed_dim = config.hidden_size
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
-        self.start_layer, self.end_layer = get_pp_indices(
+        self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
-            get_pp_group().rank_in_group,
-            get_pp_group().world_size)
-        self.h = nn.ModuleList(
-            [nn.Identity() for _ in range(self.start_layer)] + [
-                GPT2Block(config, cache_config, quant_config)
-                for _ in range(self.start_layer, self.end_layer)
-            ] + [
-                nn.Identity()
-                for _ in range(self.end_layer, config.num_hidden_layers)
-            ])
+            lambda: GPT2Block(config, cache_config, quant_config))
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
     def forward(
@@ -291,19 +283,20 @@ class GPT2LMHeadModel(nn.Module):
                 continue
             if not name.startswith("transformer."):
                 name = "transformer." + name
-            try:
-                param = params_dict[name]
-                # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-                # Because of this, we need to transpose the weights.
-                # Note(zhuohan): the logic below might break quantized models.
-                for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                    if conv1d_weight_name not in name:
-                        continue
-                    if not name.endswith(".weight"):
-                        continue
-                    loaded_weight = loaded_weight.t()
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            except KeyError:
-                continue
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
@@ -29,8 +29,7 @@ from transformers import LlamaConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import (get_pp_group, get_pp_indices,
-                              get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -51,6 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import is_hip, print_warning_once
 
 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers
 
 
 class LlamaMLP(nn.Module):
@@ -262,20 +262,11 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.start_layer, self.end_layer = get_pp_indices(
+        self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            get_pp_group().rank_in_group,
-            get_pp_group().world_size)
-        self.layers = nn.ModuleList(
-            [nn.Identity() for _ in range(self.start_layer)] + [
-                LlamaDecoderLayer(config=config,
-                                  cache_config=cache_config,
-                                  quant_config=quant_config)
-                for _ in range(self.start_layer, self.end_layer)
-            ] + [
-                nn.Identity()
-                for _ in range(self.end_layer, config.num_hidden_layers)
-            ])
+            lambda: LlamaDecoderLayer(config=config,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config))
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -455,12 +446,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                try:
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                except KeyError:
-                    pass
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
@@ -479,13 +472,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                         continue
                     else:
                         name = remapped_kv_scale_name
-                try:
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-                except KeyError:
-                    pass
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
 
         # If this function is called, it should always initialize KV cache scale
         # factors (or else raise an exception). Thus, handled exceptions should
@@ -1,3 +1,5 @@
+from typing import Callable, Dict, List, Tuple
+
 import torch
 
 from vllm.multimodal import BatchedTensors
@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
     inputs_embeds[mask] = torch.cat(vision_embeddings)
 
     return inputs_embeds
+
+
+class PPMissingLayer(torch.nn.Identity):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+
+def make_layers(
+    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function, taking
+    pipeline parallelism into account.
+    """
+    from vllm.distributed.parallel_state import get_pp_group
+    from vllm.distributed.utils import get_pp_indices
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+    modules = torch.nn.ModuleList(
+        [PPMissingLayer() for _ in range(start_layer)] +
+        [layer_fn() for _ in range(start_layer, end_layer)] +
+        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    return start_layer, end_layer, modules
+
+
+# NOTE: don't use lru_cache here because it can prevent garbage collection
+_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
+
+
+def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
+    """Get the names of the missing layers in a pipeline parallel model."""
+    model_id = id(model)
+    if model_id in _model_to_pp_missing_layer_names:
+        return _model_to_pp_missing_layer_names[model_id]
+
+    missing_layer_names = []
+    for name, module in model.named_modules():
+        if isinstance(module, PPMissingLayer):
+            missing_layer_names.append(name)
+    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
+
+    return missing_layer_names
+
+
+def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
+    """Check if a parameter is missing in a pipeline parallel model."""
+    for missing_layer_name in get_pp_missing_layer_names(model):
+        if name.startswith(missing_layer_name):
+            return True
+    return False
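To see how the two new helpers are meant to be used together, here is a hedged, torch-only sketch of the pattern the GPT2Model and LlamaModel hunks above follow. `ToyModel`, `make_layers_local`, `pp_missing_prefixes`, and the even per-rank split are illustrative stand-ins, not vLLM APIs; in vLLM the real `make_layers` and `is_pp_missing_parameter` shown above derive the rank and layer range from `get_pp_group()` and `get_pp_indices`.

```python
# Hedged, torch-only stand-in for the make_layers / is_pp_missing_parameter
# pattern above; the even split per rank is an assumption for illustration.
from typing import Callable, List, Tuple

import torch


class PPMissingLayer(torch.nn.Identity):
    """Placeholder for layers owned by other pipeline-parallel ranks."""

    def __init__(self, *args, **kwargs):
        super().__init__()


def make_layers_local(
    num_layers: int,
    layer_fn: Callable[[], torch.nn.Module],
    rank: int,
    world_size: int,
) -> Tuple[int, int, torch.nn.ModuleList]:
    # Even split: e.g. 24 layers on 4 ranks -> rank 1 owns layers [6, 12).
    per_rank = num_layers // world_size
    start, end = rank * per_rank, (rank + 1) * per_rank
    layers = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start)] +
        [layer_fn() for _ in range(start, end)] +
        [PPMissingLayer() for _ in range(end, num_layers)])
    return start, end, layers


def pp_missing_prefixes(model: torch.nn.Module) -> List[str]:
    # Same role as get_pp_missing_layer_names: collect placeholder submodules.
    return [name for name, module in model.named_modules()
            if isinstance(module, PPMissingLayer)]


class ToyModel(torch.nn.Module):

    def __init__(self, num_layers: int, rank: int, world_size: int):
        super().__init__()
        self.start_layer, self.end_layer, self.h = make_layers_local(
            num_layers, lambda: torch.nn.Linear(8, 8), rank, world_size)

    def load_weights(self, weights):
        params = dict(self.named_parameters())
        prefixes = pp_missing_prefixes(self)
        for name, tensor in weights:
            # Same idea as is_pp_missing_parameter: explicitly skip weights of
            # layers this rank does not own instead of catching KeyError.
            if any(name.startswith(prefix + ".") for prefix in prefixes):
                continue
            params[name].data.copy_(tensor)


model = ToyModel(num_layers=24, rank=1, world_size=4)
print(model.start_layer, model.end_layer)  # 6 12
print(sum(isinstance(m, torch.nn.Linear) for m in model.h))  # 6 real layers
model.load_weights([
    ("h.0.weight", torch.zeros(8, 8)),  # skipped: layer 0 lives on rank 0
    ("h.6.weight", torch.zeros(8, 8)),  # loaded: layer 6 belongs to rank 1
])
```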