diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu
index 81cc6dd6..de6a0ec0 100644
--- a/csrc/moe_align_block_size_kernels.cu
+++ b/csrc/moe_align_block_size_kernels.cu
@@ -95,7 +95,7 @@ void moe_align_block_size(
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     assert(num_experts <= NUM_MAX_EXPERTS);
     VLLM_DISPATCH_INTEGRAL_TYPES(
-        topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
         vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
             topk_ids.data_ptr<scalar_t>(),
             sorted_token_ids.data_ptr<int32_t>(),
diff --git a/csrc/ops.h b/csrc/ops.h
index 6e52dd81..2bcd0c2e 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -100,6 +100,13 @@ void gptq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm);

+void moe_align_block_size(
+  torch::Tensor topk_ids,
+  int num_experts,
+  int block_size,
+  torch::Tensor sorted_token_ids,
+  torch::Tensor experts_ids,
+  torch::Tensor num_tokens_post_pad);

 #ifndef USE_ROCM
 using fptr_t = uint64_t;
@@ -121,12 +128,3 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                             const std::vector<std::vector<int64_t>> &offsets);
 #endif
-
-void moe_align_block_size(
-  torch::Tensor topk_ids,
-  int num_experts,
-  int block_size,
-  torch::Tensor sorted_token_ids,
-  torch::Tensor experts_ids,
-  torch::Tensor num_tokens_post_pad
-  );
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index a8a99883..8a823569 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -57,9 +57,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   ops.def(
-  "moe_align_block_size",
-  &moe_align_block_size,
-  "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
+    "moe_align_block_size",
+    &moe_align_block_size,
+    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index a8dadce2..f36c35fd 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -23,8 +23,6 @@
 """Inference-only Mixtral model."""
 from typing import List, Optional, Tuple

-import numpy as np
-
 import torch
 import torch.nn.functional as F

@@ -33,10 +31,11 @@ from transformers import MixtralConfig

 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               ReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
@@ -47,6 +46,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -54,85 +54,77 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]


-class MixtralMLP(nn.Module):
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """

     def __init__(
         self,
         num_experts: int,
+        top_k: int,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
-        super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w2 = ReplicatedLinear(self.ffn_dim,
-                                   self.hidden_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w3 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-
-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
-class MixtralMoE(nn.Module):
-
-    def __init__(
-        self,
-        config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        params_dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}.")
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(range(
-            self.num_total_experts), self.tp_size)[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(
-                f"Rank {self.rank} has no experts assigned to it.")
+        tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // tp_size

-        self.experts = nn.ModuleList([
-            MixtralMLP(self.num_total_experts,
-                       config.hidden_size,
-                       config.intermediate_size,
-                       linear_method=linear_method)
-            if idx in self.expert_indicies else None
-            for idx in range(self.num_total_experts)
-        ])
-        self.gate = ReplicatedLinear(config.hidden_size,
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
+                                     params_dtype=self.params_dtype,
                                      linear_method=None)

+        self.ws = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+        self.w2s = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)

@@ -142,22 +134,18 @@ class MixtralMoE(nn.Module):
                                                        dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = (selected_experts == expert_idx)
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
-                                                                 keepdim=True)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.ws,
+                                        self.w2s,
+                                        routing_weights,
+                                        selected_experts,
+                                        inplace=True)

-            current_hidden_states = expert_layer(hidden_states).mul_(
-                expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
+        final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states)

-        return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+        return final_hidden_states.view(batch_size, sequence_length,
+                                        hidden_size)


 class MixtralAttention(nn.Module):
@@ -257,8 +245,11 @@ class MixtralDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
             linear_method=linear_method)
-        self.block_sparse_moe = MixtralMoE(config=config,
-                                           linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
@@ -387,6 +386,7 @@ class MixtralForCausalLM(nn.Module):
                 fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -399,14 +399,22 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
- if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)