[Kernel] FP8 support for MoE kernel / Mixtral (#4244)

This PR is the first step towards fixing https://github.com/vllm-project/vllm/pull/3208

It implements dynamic per-tensor scaling (see https://github.com/vllm-project/vllm/pull/4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this:

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8")

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

**Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in https://github.com/vllm-project/vllm/pull/3954). With this PR, the results are as follows:

<img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03">


**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:

```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7018|±  |0.0036|
| - humanities     |N/A    |none  |     5|acc   |0.6472|±  |0.0065|
| - other          |N/A    |none  |     5|acc   |0.7673|±  |0.0072|
| - social_sciences|N/A    |none  |     5|acc   |0.8099|±  |0.0070|
| - stem           |N/A    |none  |     5|acc   |0.6131|±  |0.0083|
```
this compares favorably with the fp16 results which are
```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7020|±  |0.1313|
| - humanities     |N/A    |none  |     5|acc   |0.6425|±  |0.1349|
| - other          |N/A    |none  |     5|acc   |0.7744|±  |0.1038|
| - social_sciences|N/A    |none  |     5|acc   |0.8131|±  |0.0695|
| - stem           |N/A    |none  |     5|acc   |0.6108|±  |0.1383|
```

Happy hacking!
This commit is contained in:
Philipp Moritz 2024-04-23 18:18:23 -07:00 committed by GitHub
parent 1e8f4252aa
commit eace8bf0b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 385 additions and 21 deletions

View File

@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")

View File

@ -146,6 +146,11 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);
void scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,

View File

@ -73,6 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,

View File

@ -0,0 +1,103 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace vllm {
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
}
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template<typename scalar_t>
__global__ void segmented_max_reduction(
float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int i = blockDim.x * blockIdx.x + threadIdx.x;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;
__syncthreads();
// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
}
}
template<typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
while (i < num_elems) {
out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
i += blockDim.x * gridDim.x;
}
}
} // namespace vllm
void scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(),
input.data_ptr<scalar_t>(),
num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
});
}

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict, Optional, Tuple
import torch
@ -153,6 +153,14 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)
# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
vllm_ops.scaled_fp8_quant(output, input, scale)
return output, scale
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,

View File

@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
}
}

View File

@ -21,6 +21,8 @@ def fused_moe_kernel(
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
@ -49,6 +51,7 @@ def fused_moe_kernel(
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
@ -111,6 +114,10 @@ def fused_moe_kernel(
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
if use_fp8:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@ -129,7 +136,10 @@ def fused_moe_kernel(
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@ -140,7 +150,10 @@ def fused_moe_kernel(
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
if use_fp8:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@ -207,15 +220,24 @@ def moe_align_block_size(
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
B_scale: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any]) -> None:
config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if not use_fp8:
A_scale = None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A)
assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A,
B,
C,
A_scale,
B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
compute_type=compute_type,
use_fp8=use_fp8,
**config,
)
def get_config_file_name(E: int, N: int) -> str:
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
device_name = torch.cuda.get_device_name().replace(" ", "_")
return f"E={E},N={N},device_name={device_name}.json"
dtype_selector = "" if not dtype else f",dtype={dtype}"
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
@functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
def get_moe_configs(E: int, N: int,
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_config_file_name(E, N)
json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
@ -288,6 +315,9 @@ def fused_moe(
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@ -305,6 +335,12 @@ def fused_moe(
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@ -358,7 +394,8 @@ def fused_moe(
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2])
configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
@ -394,17 +431,37 @@ def fused_moe(
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, False,
topk_ids.shape[1], config)
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=tl.float16,
use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, True, 1,
config)
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=tl.float16,
use_fp8=use_fp8)
if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),

View File

@ -232,6 +232,8 @@ class DefaultModelLoader(BaseModelLoader):
linear_method = getattr(module, "linear_method", None)
if linear_method is not None:
linear_method.process_weights_after_loading(module)
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
return model.eval()

View File

@ -24,6 +24,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and model_config.quantization != "fp8"
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

View File

@ -39,6 +39,8 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod,
per_tensor_quantize)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -47,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once
class MixtralMoE(nn.Module):
@ -66,6 +69,7 @@ class MixtralMoE(nn.Module):
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
@ -73,6 +77,9 @@ class MixtralMoE(nn.Module):
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@ -97,6 +104,16 @@ class MixtralMoE(nn.Module):
device="cuda",
dtype=self.params_dtype))
# Scaling factors for FP8 weights
self.ws_scale = nn.Parameter(
torch.ones(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
self.w2s_scale = nn.Parameter(
torch.ones(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
@ -118,6 +135,18 @@ class MixtralMoE(nn.Module):
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
def process_weights_after_loading(self):
if self.use_fp8:
ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
self.ws.data[expert, :, :])
w2s[expert, :, :], self.w2s_scale[
expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
self.ws = nn.Parameter(ws, requires_grad=False)
self.w2s = nn.Parameter(w2s, requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
@ -129,7 +158,10 @@ class MixtralMoE(nn.Module):
router_logits,
self.top_k,
renormalize=True,
inplace=True)
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.ws_scale,
w2_scale=self.w2s_scale)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
@ -171,6 +203,13 @@ class MixtralAttention(nn.Module):
self.rope_theta = rope_theta
self.sliding_window = sliding_window
if isinstance(linear_method, Fp8LinearMethod):
print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
)
linear_method = None
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
@ -238,7 +277,8 @@ class MixtralDecoderLayer(nn.Module):
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size)
intermediate_size=config.intermediate_size,
linear_method=linear_method)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,