[Model] Pipeline parallel support for Mixtral (#6516)
This commit is contained in:
parent
b5241e41d9
commit
b5af8c223c
@ -1,4 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from ..utils import RemoteOpenAIServer
|
from ..utils import RemoteOpenAIServer
|
||||||
|
|
||||||
@ -12,6 +13,8 @@ from ..utils import RemoteOpenAIServer
|
|||||||
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
|
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
|
||||||
])
|
])
|
||||||
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
|
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||||
|
|
||||||
pp_args = [
|
pp_args = [
|
||||||
# use half precision for speed and memory savings in CI environment
|
# use half precision for speed and memory savings in CI environment
|
||||||
"--dtype",
|
"--dtype",
|
||||||
@ -34,7 +37,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
|
|||||||
"--dtype",
|
"--dtype",
|
||||||
"bfloat16",
|
"bfloat16",
|
||||||
"--tensor-parallel-size",
|
"--tensor-parallel-size",
|
||||||
str(max(TP_SIZE, 2)), # use at least TP_SIZE=2 to hold the model
|
str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI.
|
||||||
"--distributed-executor-backend",
|
"--distributed-executor-backend",
|
||||||
"mp",
|
"mp",
|
||||||
]
|
]
|
||||||
@ -45,8 +48,10 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
|
|||||||
pp_args.append("--enforce-eager")
|
pp_args.append("--enforce-eager")
|
||||||
tp_args.append("--enforce-eager")
|
tp_args.append("--enforce-eager")
|
||||||
|
|
||||||
|
prompt = "Hello, my name is"
|
||||||
|
token_ids = tokenizer(prompt)["input_ids"]
|
||||||
results = []
|
results = []
|
||||||
for args in [pp_args, tp_args]:
|
for args in (pp_args, tp_args):
|
||||||
with RemoteOpenAIServer(MODEL_NAME, args) as server:
|
with RemoteOpenAIServer(MODEL_NAME, args) as server:
|
||||||
client = server.get_client()
|
client = server.get_client()
|
||||||
|
|
||||||
@ -62,7 +67,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
|
|||||||
|
|
||||||
# test with text prompt
|
# test with text prompt
|
||||||
completion = client.completions.create(model=MODEL_NAME,
|
completion = client.completions.create(model=MODEL_NAME,
|
||||||
prompt="Hello, my name is",
|
prompt=prompt,
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
|
|
||||||
@ -76,7 +81,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
|
|||||||
# test using token IDs
|
# test using token IDs
|
||||||
completion = client.completions.create(
|
completion = client.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
prompt=[0, 0, 0, 0, 0],
|
prompt=token_ids,
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
)
|
)
|
||||||
@ -91,7 +96,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
|
|||||||
# test simple list
|
# test simple list
|
||||||
batch = client.completions.create(
|
batch = client.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
prompt=["Hello, my name is", "Hello, my name is"],
|
prompt=[prompt, prompt],
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
)
|
)
|
||||||
@ -105,7 +110,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
|
|||||||
# test streaming
|
# test streaming
|
||||||
batch = client.completions.create(
|
batch = client.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
prompt=["Hello, my name is", "Hello, my name is"],
|
prompt=[prompt, prompt],
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
|||||||
@ -34,6 +34,7 @@ _PP_SUPPORTED_MODELS = [
|
|||||||
"MistralForCausalLM",
|
"MistralForCausalLM",
|
||||||
"Phi3ForCausalLM",
|
"Phi3ForCausalLM",
|
||||||
"GPT2LMHeadModel",
|
"GPT2LMHeadModel",
|
||||||
|
"MixtralForCausalLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from transformers import MixtralConfig
|
|||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig, LoRAConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||||
@ -48,6 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA
|
from .interfaces import SupportsLoRA
|
||||||
|
from .utils import is_pp_missing_parameter, make_layers
|
||||||
|
|
||||||
|
|
||||||
class MixtralMoE(nn.Module):
|
class MixtralMoE(nn.Module):
|
||||||
@ -255,12 +256,11 @@ class MixtralModel(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
org_num_embeddings=config.vocab_size,
|
org_num_embeddings=config.vocab_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
|
||||||
MixtralDecoderLayer(config,
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
cache_config,
|
config.num_hidden_layers, lambda: MixtralDecoderLayer(
|
||||||
quant_config=quant_config)
|
config, cache_config, quant_config=quant_config))
|
||||||
for _ in range(config.num_hidden_layers)
|
|
||||||
])
|
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -269,14 +269,25 @@ class MixtralModel(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
if get_pp_group().is_first_rank:
|
||||||
residual = None
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
for i in range(len(self.layers)):
|
residual = None
|
||||||
|
else:
|
||||||
|
assert intermediate_tensors is not None
|
||||||
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
residual = intermediate_tensors["residual"]
|
||||||
|
for i in range(self.start_layer, self.end_layer):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(positions, hidden_states,
|
hidden_states, residual = layer(positions, hidden_states,
|
||||||
kv_caches[i], attn_metadata,
|
kv_caches[i - self.start_layer],
|
||||||
residual)
|
attn_metadata, residual)
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states": hidden_states,
|
||||||
|
"residual": residual
|
||||||
|
})
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@ -347,7 +358,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
attn_metadata)
|
attn_metadata, intermediate_tensors)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def compute_logits(self, hidden_states: torch.Tensor,
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
@ -356,6 +367,20 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
sampling_metadata)
|
sampling_metadata)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def make_empty_intermediate_tensors(
|
||||||
|
self, batch_size: int, dtype: torch.dtype,
|
||||||
|
device: torch.device) -> IntermediateTensors:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states":
|
||||||
|
torch.zeros((batch_size, self.config.hidden_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device),
|
||||||
|
"residual":
|
||||||
|
torch.zeros((batch_size, self.config.hidden_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device),
|
||||||
|
})
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
logits: Optional[torch.Tensor],
|
logits: Optional[torch.Tensor],
|
||||||
@ -392,6 +417,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
@ -402,6 +431,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param,
|
weight_loader(param,
|
||||||
@ -414,6 +446,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
# Skip layers on other devices.
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
# Remapping the name of FP8 kv-scale.
|
# Remapping the name of FP8 kv-scale.
|
||||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
if name is None:
|
if name is None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user