From c467dff24f5dfa0b8e4045c3d265c6e606e35f99 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 16 Jul 2024 09:56:28 -0700
Subject: [PATCH] [Hardware][TPU] Support MoE with Pallas GMM kernel (#6457)

---
 Dockerfile.tpu                                |  4 +-
 .../getting_started/tpu-installation.rst      |  4 +-
 vllm/model_executor/layers/fused_moe/layer.py | 18 ++++++
 .../layers/fused_moe/moe_pallas.py            | 62 +++++++++++++++++++
 vllm/worker/tpu_model_runner.py               |  9 ++-
 5 files changed, 89 insertions(+), 8 deletions(-)
 create mode 100644 vllm/model_executor/layers/fused_moe/moe_pallas.py

diff --git a/Dockerfile.tpu b/Dockerfile.tpu
index 6ad8e8cc..be7dbe63 100644
--- a/Dockerfile.tpu
+++ b/Dockerfile.tpu
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240601"
+ARG NIGHTLY_DATE="20240713"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE
@@ -6,6 +6,8 @@ WORKDIR /workspace
 
 # Install aiohttp separately to avoid build errors.
 RUN pip install aiohttp
+# Install NumPy 1 instead of NumPy 2.
+RUN pip install "numpy<2"
 # Install the TPU and Pallas dependencies.
 RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
 RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst
index e96aabbb..5e2f514a 100644
--- a/docs/source/getting_started/tpu-installation.rst
+++ b/docs/source/getting_started/tpu-installation.rst
@@ -56,7 +56,7 @@ First, install the dependencies:
     $ pip uninstall torch torch-xla -y
 
     $ # Install PyTorch and PyTorch XLA.
-    $ export DATE="+20240601"
+    $ export DATE="+20240713"
     $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
     $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 
@@ -85,7 +85,7 @@ Next, build vLLM from source. This will only take a few seconds:
 
         ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory
 
-    You can install OpenBLAS with the following command:
+    Please install OpenBLAS with the following command:
 
     .. code-block:: console
 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 7f066860..bb2be3f3 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -104,6 +104,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
 
+    def forward_tpu(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
+        assert not use_grouped_topk
+        assert num_expert_group is None
+        assert topk_group is None
+        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+
 
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
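For readers skimming the patch: forward_tpu above rejects grouped top-k and hands routing off to the Pallas-backed fused_moe. The snippet below is only a minimal CPU sketch of that routing step (softmax over the router logits, top-k selection, optional renormalization); the tensor shapes and the top-2 value are illustrative assumptions, not values taken from this patch.

import torch


def topk_routing(router_logits: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over experts in float32, then pick the top-k experts per token.
    weights = router_logits.softmax(dim=-1, dtype=torch.float)
    topk_weights, topk_indices = weights.topk(topk, dim=-1)
    if renormalize:
        # Rescale the selected weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_indices


# Illustrative shapes only: 8 tokens, 4 experts, top-2 routing.
router_logits = torch.randn(8, 4)
topk_weights, topk_indices = topk_routing(router_logits, topk=2, renormalize=True)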
diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py
new file mode 100644
index 00000000..563ee18c
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn.functional as F
+from torch_xla.experimental.custom_kernel import _histogram
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+) -> torch.Tensor:
+    """
+    Args:
+        hidden_states: [*, hidden_size]
+        w1: [num_experts, intermediate_size * 2, hidden_size]
+        w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, num_experts]
+    """
+    orig_shape = hidden_states.shape
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.shape[:-1].numel()
+    num_experts = w1.shape[0]
+    intermediate_size = w2.shape[-1]
+    device = hidden_states.device
+    dtype = hidden_states.dtype
+    assert (num_tokens * topk) % 16 == 0, (
+        "The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
+        f"16 but got {num_tokens * topk}")
+
+    hidden_states = hidden_states.view(num_tokens, hidden_size)
+    gating_output = gating_output.view(num_tokens, num_experts)
+    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
+    topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights = topk_weights.to(dtype)
+
+    topk_indices = topk_indices.flatten()
+    topk_argsort_indices = topk_indices.argsort()
+    topk_argsort_revert_indices = topk_argsort_indices.argsort()
+    token_indices = torch.arange(num_tokens,
+                                 device=device).repeat_interleave(topk)
+    token_indices = token_indices[topk_argsort_indices]
+    group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
+
+    # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
+    # from HF Transformers.
+    w1 = w1.transpose(1, 2)
+    w2 = w2.transpose(1, 2)
+
+    x = hidden_states[token_indices]
+    x = torch.ops.xla.gmm(x, w1, group_sizes)
+    x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
+    x = torch.ops.xla.gmm(x, w2, group_sizes)
+    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
+
+    x = x * topk_weights.unsqueeze_(dim=-1)
+    x = x.sum(dim=-2)
+    x = x.reshape(orig_shape)
+    return x
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 9b00a60a..6c1149ee 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -598,11 +598,10 @@ def _get_padded_prefill_len(x: int) -> int:
 
 
 def _get_padded_batch_size(batch_size: int) -> int:
-    if batch_size <= 2:
-        return batch_size
-    elif batch_size <= 4:
-        return 4
-    elif batch_size <= 8:
+    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
+    # To meet this requirement in the simplest way, we set the minimal batch
+    # size to 8.
+    if batch_size <= 8:
         return 8
     else:
         return ((batch_size + 15) // 16) * 16
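As a quick sanity check of the padding change above, the sketch below restates the padding rule and verifies that every padded decode batch keeps num_tokens * topk a multiple of 16, assuming topk == 2 (a top-2 value chosen only for illustration; it is not taken from this patch).

def padded_batch_size(batch_size: int) -> int:
    # Mirrors the updated helper: pad small batches up to 8, larger batches
    # up to the next multiple of 16.
    if batch_size <= 8:
        return 8
    return ((batch_size + 15) // 16) * 16


TOPK = 2  # assumed top-k per token; the kernel needs num_tokens * topk % 16 == 0
for batch_size in range(1, 257):
    padded = padded_batch_size(batch_size)
    assert padded >= batch_size
    assert (padded * TOPK) % 16 == 0

Note that with larger batches the padded size is itself a multiple of 16, so the constraint holds for any top-k; the assumed top-2 value only matters for the minimum batch size of 8.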