[Kernel] FP8 support for MoE kernel / Mixtral (#4244)
This PR is the first step towards fixing https://github.com/vllm-project/vllm/pull/3208. It implements dynamic per-tensor scaling (see https://github.com/vllm-project/vllm/pull/4118), so users do not need to compute activation scales on a calibration dataset, and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this:

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1",
          tensor_parallel_size=2,
          quantization="fp8")

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

**Performance**: For this PR, the focus is on keeping the code clean while still getting reasonable performance. There is a set of optimizations that we will submit in a follow-up PR which significantly improves performance (similar to the numbers in https://github.com/vllm-project/vllm/pull/3954). With this PR, the results are as follows:

<img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03">

**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:

```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7018|±  |0.0036|
| - humanities     |N/A    |none  |     5|acc   |0.6472|±  |0.0065|
| - other          |N/A    |none  |     5|acc   |0.7673|±  |0.0072|
| - social_sciences|N/A    |none  |     5|acc   |0.8099|±  |0.0070|
| - stem           |N/A    |none  |     5|acc   |0.6131|±  |0.0083|
```

This compares favorably with the fp16 results, which are:

```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7020|±  |0.1313|
| - humanities     |N/A    |none  |     5|acc   |0.6425|±  |0.1349|
| - other          |N/A    |none  |     5|acc   |0.7744|±  |0.1038|
| - social_sciences|N/A    |none  |     5|acc   |0.8131|±  |0.0695|
| - stem           |N/A    |none  |     5|acc   |0.6108|±  |0.1383|
```

Happy hacking!
Commit eace8bf0b9 (parent 1e8f4252aa)
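Before the diff: for intuition, here is a minimal PyTorch sketch (not vLLM code; `per_tensor_fp8_quant` is a made-up helper name) of what dynamic per-tensor scaling means. The scale is derived from the tensor's own absolute maximum at runtime, which is why no calibration dataset or checkpoint conversion is needed:

```python
import torch

def per_tensor_fp8_quant(x: torch.Tensor):
    # Dynamic per-tensor scaling: derive the scale from the tensor itself.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max            # 448.0 for e4m3
    scale = (x.abs().max().float() / fp8_max).clamp(min=1e-12)
    x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale          # dequantize with x_fp8.float() * scale

x = torch.randn(16, 4096, dtype=torch.float16)
x_fp8, scale = per_tensor_fp8_quant(x)
print(x_fp8.dtype, scale.item())  # torch.float8_e4m3fn, data-dependent scale
```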
```diff
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
   "csrc/layernorm_kernels.cu"
   "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
   "csrc/quantization/gptq/q_gemm.cu"
+  "csrc/quantization/fp8/fp8_cuda_kernels.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
   "csrc/pybind.cpp")
```
```diff
@@ -146,6 +146,11 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
 
+void scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);
+
 void moe_align_block_size(
   torch::Tensor topk_ids,
   int num_experts,
```
```diff
@@ -73,6 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+  ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
   ops.def(
     "moe_align_block_size",
     &moe_align_block_size,
```
**New file** `csrc/quantization/fp8/fp8_cuda_kernels.cu` (103 lines):

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
    __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template<typename scalar_t>
__global__ void segmented_max_reduction(
  float* __restrict__ scale,
  const scalar_t* __restrict__ input,
  int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store maximum for all values processes by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

template<typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
  c10::Float8_e4m3fn* __restrict__ out,
  const scalar_t* __restrict__ input,
  const float* __restrict__ scale,
  int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  while (i < num_elems) {
    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
    i += blockDim.x * gridDim.x;
  }
}

} // namespace vllm

void scaled_fp8_quant(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input,  // [..., d]
  torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "scaled_fp8_quant_kernel",
    [&] {
      vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
        scale.data_ptr<float>(),
        input.data_ptr<scalar_t>(),
        num_elems);
      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<c10::Float8_e4m3fn>(),
        input.data_ptr<scalar_t>(),
        scale.data_ptr<float>(),
        num_elems);
    });
}
```
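One non-obvious detail in the kernel above is `atomicMaxFloat`, which implements a float atomic max via integer `atomicMax`/`atomicMin`. This works because, for IEEE-754 floats, the bit patterns of non-negative values sort in the same order as the values themselves when reinterpreted as signed integers; the Python wrapper added below zero-initializes `scale` and the kernel only feeds it non-negative absolute maxima, so the signed-integer branch is the one that matters. A quick sanity check of that ordering property, sketched in plain Python (illustration only, not part of the diff):

```python
import struct

def float_bits_as_int(f: float) -> int:
    # Reinterpret the 32-bit float's bit pattern as a signed 32-bit integer.
    return struct.unpack("<i", struct.pack("<f", f))[0]

values = [0.0, 1e-30, 0.5, 1.0, 3.14, 448.0, 1e30]
bits = [float_bits_as_int(v) for v in values]
assert bits == sorted(bits)   # ordering of non-negative floats is preserved
print(bits)
```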
```diff
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
 
 import torch
 
@@ -153,6 +153,14 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        size_n, size_k)
 
 
+# fp8
+def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    vllm_ops.scaled_fp8_quant(output, input, scale)
+    return output, scale
+
+
 # moe
 def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                          block_size: int, sorted_token_ids: torch.Tensor,
```
**New file** (146 lines): a tuned Triton configuration table for the fused MoE kernel:

```json
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    }
}
```
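The keys in the table above are batch sizes (numbers of tokens seen by the kernel). As a rough illustration of how such a table is consumed, the entry whose key is closest to the actual token count is picked; the helper below is a sketch of that lookup, not the exact vLLM code:

```python
import json
from typing import Any, Dict

def pick_config(configs: Dict[str, Any], num_tokens: int) -> Dict[str, Any]:
    # Choose the tuned entry whose batch-size key is closest to num_tokens.
    best_key = min(configs, key=lambda k: abs(int(k) - num_tokens))
    return configs[best_key]

table = json.loads('{"1": {"BLOCK_SIZE_M": 16}, "256": {"BLOCK_SIZE_M": 128}}')
print(pick_config(table, 200))   # -> {'BLOCK_SIZE_M': 128}
```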
```diff
@@ -21,6 +21,8 @@ def fused_moe_kernel(
     a_ptr,
     b_ptr,
     c_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
     topk_weights_ptr,
     sorted_token_ids_ptr,
     expert_ids_ptr,
@@ -49,6 +51,7 @@ def fused_moe_kernel(
     MUL_ROUTED_WEIGHT: tl.constexpr,
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
+    use_fp8: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -111,6 +114,10 @@ def fused_moe_kernel(
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
 
+    if use_fp8:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr + off_experts)
+
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@@ -129,7 +136,10 @@ def fused_moe_kernel(
                     mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                     other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        if use_fp8:
+            accumulator = tl.dot(a, b, acc=accumulator)
+        else:
+            accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -140,7 +150,10 @@ def fused_moe_kernel(
                              other=0)
         accumulator = accumulator * moe_weight[:, None]
 
-    accumulator = accumulator.to(compute_type)
+    if use_fp8:
+        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
```
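The two kernel changes above are the heart of the FP8 path: the fp8 operands are the original matrices divided by their per-tensor scales, the dot products are accumulated in higher precision, and the accumulator is multiplied by `a_scale * b_scale` at the end to undo both scales at once. A small standalone PyTorch check of that identity (illustration only, not vLLM code):

```python
import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max
A = torch.randn(64, 128) * 3.0
B = torch.randn(128, 32) * 5.0
a_scale = A.abs().max() / fp8_max
b_scale = B.abs().max() / fp8_max

A_q = (A / a_scale).to(torch.float8_e4m3fn)   # quantized activations
B_q = (B / b_scale).to(torch.float8_e4m3fn)   # quantized weights

# Accumulate in fp32 (as the Triton kernel does), then undo both scales.
approx = (A_q.float() @ B_q.float()) * a_scale * b_scale
rel_err = (approx - A @ B).abs().max() / (A @ B).abs().max()
print(rel_err)   # small: fp8 e4m3 keeps roughly 2-3 significant digits
```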
```diff
@@ -207,15 +220,24 @@ def moe_align_block_size(
 
 
 def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            B_scale: torch.Tensor, topk_weights: torch.Tensor,
+                            topk_ids: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any]) -> None:
+                            config: Dict[str, Any], compute_type: tl.dtype,
+                            use_fp8: bool) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    if not use_fp8:
+        A_scale = None
+        assert B_scale is None
+    else:
+        A, A_scale = ops.scaled_fp8_quant(A)
+        assert B_scale is not None
+
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
 
@@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A,
         B,
         C,
+        A_scale,
+        B_scale,
         topk_weights,
         sorted_token_ids,
         expert_ids,
@@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         C.stride(2),
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
-        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
         **config,
     )
 
 
-def get_config_file_name(E: int, N: int) -> str:
+def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
     device_name = torch.cuda.get_device_name().replace(" ", "_")
-    return f"E={E},N={N},device_name={device_name}.json"
+    dtype_selector = "" if not dtype else f",dtype={dtype}"
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 @functools.lru_cache
-def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int,
+                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
```
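For reference, the updated `get_config_file_name` simply appends a `,dtype=...` selector when a dtype is given. The snippet below mirrors that formatting with the device name passed in explicitly; the device name shown is illustrative, not taken from this PR:

```python
def config_file_name(E: int, N: int, dtype, device_name: str) -> str:
    # Mirrors the helper above, minus the torch.cuda.get_device_name() call.
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"

print(config_file_name(8, 7168, None, "NVIDIA_H100_80GB_HBM3"))
# E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json
print(config_file_name(8, 7168, "float8", "NVIDIA_H100_80GB_HBM3"))
# E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
```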
```diff
@@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
 
     # First look up if an optimized configuration is available in the configs
     # directory
-    json_file_name = get_config_file_name(E, N)
+    json_file_name = get_config_file_name(E, N, dtype)
 
     config_file_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
@@ -288,6 +315,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -305,6 +335,12 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -358,7 +394,8 @@ def fused_moe(
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2])
+        configs = get_moe_configs(E, w2.shape[2],
+                                  "float8" if use_fp8 else None)
 
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -394,17 +431,37 @@ def fused_moe(
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
 
-    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, False,
-                            topk_ids.shape[1], config)
+    invoke_fused_moe_kernel(hidden_states,
+                            w1,
+                            intermediate_cache1,
+                            w1_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            False,
+                            topk_ids.shape[1],
+                            config,
+                            compute_type=tl.float16,
+                            use_fp8=use_fp8)
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, True, 1,
-                            config)
+    invoke_fused_moe_kernel(intermediate_cache2,
+                            w2,
+                            intermediate_cache3,
+                            w2_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            True,
+                            1,
+                            config,
+                            compute_type=tl.float16,
+                            use_fp8=use_fp8)
 
     if inplace:
         return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
```
```diff
@@ -232,6 +232,8 @@ class DefaultModelLoader(BaseModelLoader):
                 linear_method = getattr(module, "linear_method", None)
                 if linear_method is not None:
                     linear_method.process_weights_after_loading(module)
+                if hasattr(module, "process_weights_after_loading"):
+                    module.process_weights_after_loading()
         return model.eval()
 
 
```
```diff
@@ -24,6 +24,7 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     if (model_config.quantization is not None
+            and model_config.quantization != "fp8"
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
 
```
```diff
@@ -39,6 +39,8 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod,
+                                                         per_tensor_quantize)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -47,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
+from vllm.utils import print_warning_once
 
 
 class MixtralMoE(nn.Module):
@@ -66,6 +69,7 @@ class MixtralMoE(nn.Module):
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
         tp_size: Optional[int] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
@@ -73,6 +77,9 @@ class MixtralMoE(nn.Module):
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size // self.tp_size
+        # FIXME(pcmoritz): Make this more general to support different
+        # quantization schemes
+        self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -97,6 +104,16 @@ class MixtralMoE(nn.Module):
                                                 device="cuda",
                                                 dtype=self.params_dtype))
 
+        # Scaling factors for FP8 weights
+        self.ws_scale = nn.Parameter(
+            torch.ones(
+                self.num_total_experts, device="cuda", dtype=torch.float32),
+            requires_grad=False) if self.use_fp8 else None
+        self.w2s_scale = nn.Parameter(
+            torch.ones(
+                self.num_total_experts, device="cuda", dtype=torch.float32),
+            requires_grad=False) if self.use_fp8 else None
+
         set_weight_attrs(self.ws, {
             "weight_loader": self.weight_loader,
         })
@@ -118,6 +135,18 @@ class MixtralMoE(nn.Module):
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
+    def process_weights_after_loading(self):
+        if self.use_fp8:
+            ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
+            w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
+                    self.ws.data[expert, :, :])
+                w2s[expert, :, :], self.w2s_scale[
+                    expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
+            self.ws = nn.Parameter(ws, requires_grad=False)
+            self.w2s = nn.Parameter(w2s, requires_grad=False)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
@@ -129,7 +158,10 @@ class MixtralMoE(nn.Module):
                                         router_logits,
                                         self.top_k,
                                         renormalize=True,
-                                        inplace=True)
+                                        inplace=True,
+                                        use_fp8=self.use_fp8,
+                                        w1_scale=self.ws_scale,
+                                        w2_scale=self.w2s_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -171,6 +203,13 @@ class MixtralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
+        if isinstance(linear_method, Fp8LinearMethod):
+            print_warning_once(
+                "For Mixtral FP8 quantization, we currently do not quantize "
+                "the attention layers until their FP8 performance is improved."
+            )
+            linear_method = None
+
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
             self.head_dim,
@@ -238,7 +277,8 @@ class MixtralDecoderLayer(nn.Module):
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size)
+            intermediate_size=config.intermediate_size,
+            linear_method=linear_method)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
```