Add unit test for Mixtral MoE layer (#2677)
This commit is contained in:
parent 89efcf1ce5
commit d0d93b92b1
@@ -7,6 +7,12 @@ FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
 RUN apt-get update -y \
     && apt-get install -y python3-pip git
 
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-12.1/compat/
+
 WORKDIR /workspace
 
 # install build and runtime dependencies
@@ -1,50 +0,0 @@
import pytest
import torch

from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.activation import SiluAndMul


def torch_moe(a, w1, w2, topk_weight, topk_ids):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
    out = torch.zeros(B * topk_ids.shape[1],
                      w2.shape[1],
                      dtype=a.dtype,
                      device=a.device)
    topk_ids = topk_ids.view(-1)
    topk_weight = topk_weight.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1)).sum(dim=1)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

    score = torch.randn((m, e), device='cuda', dtype=dtype)
    score = torch.softmax(score, dim=-1)
    topk_weight, topk_ids = torch.topk(score, topk)

    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
104 tests/kernels/test_moe.py Normal file
@@ -0,0 +1,104 @@
"""Tests for the MOE layers.
|
||||
|
||||
Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
|
||||
|
||||
def torch_moe(a, w1, w2, topk_weight, topk_ids):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk_ids.shape[1],
|
||||
w2.shape[1],
|
||||
dtype=a.dtype,
|
||||
device=a.device)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(
|
||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1)).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@pytest.mark.parametrize("e", [8, 64])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_fused_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
|
||||
|
||||
score = torch.randn((m, e), device='cuda', dtype=dtype)
|
||||
score = torch.softmax(score, dim=-1)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
|
||||
triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
|
||||
torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_mixtral_moe(dtype: torch.dtype):
|
||||
"Make sure our Mixtral MoE implementation agrees with the one from huggingface."
|
||||
|
||||
# Instantiate our and huggingface's MoE blocks
|
||||
config = MixtralConfig()
|
||||
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
|
||||
vllm_moe = MixtralMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
params_dtype=dtype,
|
||||
tp_size=1,
|
||||
)
|
||||
|
||||
# Load the weights
|
||||
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
|
||||
for i in range(config.num_local_experts):
|
||||
weights = (hf_moe.experts[i].w1.weight.data,
|
||||
hf_moe.experts[i].w3.weight.data)
|
||||
vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
|
||||
vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
|
||||
|
||||
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
|
||||
inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
|
||||
|
||||
# Run forward passes for both MoE blocks
|
||||
hf_states, _ = hf_moe.forward(inputs)
|
||||
vllm_states = vllm_moe.forward(inputs)
|
||||
|
||||
mixtral_moe_tol = {
|
||||
torch.float32: 1e-3,
|
||||
torch.float16: 1e-3,
|
||||
torch.bfloat16: 1e-2,
|
||||
}
|
||||
|
||||
assert torch.allclose(hf_states,
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype])
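Aside, not part of the diff: the weight-loading loop above concatenates each HF expert's w1 (gate) and w3 (up) projections into the single fused tensor stored in vllm_moe.ws, which is the layout the SiluAndMul activation in torch_moe expects. A minimal standalone sketch of that equivalence, using torch.nn.functional.silu in place of vLLM's SiluAndMul and made-up sizes:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 32           # arbitrary illustrative sizes
x = torch.randn(4, hidden_size)
w1 = torch.randn(intermediate_size, hidden_size)  # gate projection
w3 = torch.randn(intermediate_size, hidden_size)  # up projection
w_fused = torch.cat((w1, w3), dim=0)              # same layout as vllm_moe.ws[i]

# SiluAndMul splits its input in half: silu(first half) * second half.
fused_out = x @ w_fused.t()
ref = F.silu(x @ w1.t()) * (x @ w3.t())
assert torch.allclose(
    F.silu(fused_out[:, :intermediate_size]) * fused_out[:, intermediate_size:],
    ref)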
@@ -235,7 +235,9 @@ def fused_moe(hidden_states: torch.Tensor,
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
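Aside, not part of the diff: with the widened assertion, a float32 call through fused_moe in the same style as test_fused_moe above is accepted (previously it would have tripped the dtype assert). A rough sketch with arbitrary sizes; assumes a CUDA device and the positional signature used by the test:

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

m, n, k, e, topk = 33, 256, 128, 8, 2             # arbitrary illustrative sizes
a = torch.randn((m, k), device='cuda', dtype=torch.float32) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=torch.float32) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=torch.float32) / 10
score = torch.softmax(torch.randn((m, e), device='cuda'), dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)
out = fused_moe(a, w1, w2, topk_weight, topk_ids, False)  # -> shape (m, k)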
@@ -70,13 +70,14 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // tp_size
+        self.intermediate_size = intermediate_size // self.tp_size
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -141,8 +142,9 @@ class MixtralMoE(nn.Module):
                                         selected_experts,
                                         inplace=True)
 
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(batch_size, sequence_length,
                                         hidden_size)