diff --git a/Dockerfile b/Dockerfile index 4cfcf058..364345d6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,12 @@ FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev RUN apt-get update -y \ && apt-get install -y python3-pip git +# Workaround for https://github.com/openai/triton/issues/2507 and +# https://github.com/pytorch/pytorch/issues/107960 -- hopefully +# this won't be needed for future versions of this docker image +# or future versions of triton. +RUN ldconfig /usr/local/cuda-12.1/compat/ + WORKDIR /workspace # install build and runtime dependencies diff --git a/tests/kernels/test_fused_moe.py b/tests/kernels/test_fused_moe.py deleted file mode 100644 index 80a0349d..00000000 --- a/tests/kernels/test_fused_moe.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest -import torch - -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.activation import SiluAndMul - - -def torch_moe(a, w1, w2, topk_weight, topk_ids): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) - out = torch.zeros(B * topk_ids.shape[1], - w2.shape[1], - dtype=a.dtype, - device=a.device) - topk_ids = topk_ids.view(-1) - topk_weight = topk_weight.view(-1) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1)).sum(dim=1) - - -@pytest.mark.parametrize("m", [512, 222, 33, 1]) -@pytest.mark.parametrize("n", [2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", [8, 64]) -@pytest.mark.parametrize("topk", [2, 6]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, -): - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 - - score = torch.randn((m, e), device='cuda', dtype=dtype) - score = torch.softmax(score, dim=-1) - topk_weight, topk_ids = torch.topk(score, topk) - - triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False) - torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids) - assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py new file mode 100644 index 00000000..227ddfc3 --- /dev/null +++ b/tests/kernels/test_moe.py @@ -0,0 +1,104 @@ +"""Tests for the MOE layers. + +Run `pytest tests/kernels/test_moe.py`. +""" + +import pytest +import torch + +from transformers import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.models.mixtral import MixtralMoE + + +def torch_moe(a, w1, w2, topk_weight, topk_ids): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) + out = torch.zeros(B * topk_ids.shape[1], + w2.shape[1], + dtype=a.dtype, + device=a.device) + topk_ids = topk_ids.view(-1) + topk_weight = topk_weight.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1)).sum(dim=1) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", [8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + + score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.softmax(score, dim=-1) + topk_weight, topk_ids = torch.topk(score, topk) + + triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False) + torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids) + assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) + + +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +@torch.inference_mode() +def test_mixtral_moe(dtype: torch.dtype): + "Make sure our Mixtral MoE implementation agrees with the one from huggingface." + + # Instantiate our and huggingface's MoE blocks + config = MixtralConfig() + hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") + vllm_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=dtype, + tp_size=1, + ) + + # Load the weights + vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data + for i in range(config.num_local_experts): + weights = (hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data) + vllm_moe.ws[i][:] = torch.cat(weights, dim=0) + vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data + + # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] + inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") + + # Run forward passes for both MoE blocks + hf_states, _ = hf_moe.forward(inputs) + vllm_states = vllm_moe.forward(inputs) + + mixtral_moe_tol = { + torch.float32: 1e-3, + torch.float16: 1e-3, + torch.bfloat16: 1e-2, + } + + assert torch.allclose(hf_states, + vllm_states, + rtol=mixtral_moe_tol[dtype], + atol=mixtral_moe_tol[dtype]) diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py index 998062d8..eed2e83b 100644 --- a/vllm/model_executor/layers/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe.py @@ -235,7 +235,9 @@ def fused_moe(hidden_states: torch.Tensor, assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float16, torch.bfloat16] + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] M, _ = hidden_states.shape E, N, _ = w1.shape diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index f36c35fd..a8e47039 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -70,13 +70,14 @@ class MixtralMoE(nn.Module): hidden_size: int, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, ): super().__init__() - tp_size = get_tensor_model_parallel_world_size() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() self.num_total_experts = num_experts self.top_k = top_k self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // tp_size + self.intermediate_size = intermediate_size // self.tp_size if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -141,8 +142,9 @@ class MixtralMoE(nn.Module): selected_experts, inplace=True) - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) return final_hidden_states.view(batch_size, sequence_length, hidden_size)