From fa32207842f1ed5a966372ed0513914bff8426c4 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 29 Apr 2024 22:05:40 -0700
Subject: [PATCH] [Bugfix][Kernel] Fix compute_type for MoE kernel (#4463)

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d37837a0..b4f81527 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -433,6 +433,8 @@ def fused_moe(
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
+    compute_type = (tl.bfloat16
+                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
 
     invoke_fused_moe_kernel(hidden_states,
                             w1,
@@ -447,7 +449,7 @@ def fused_moe(
                             False,
                             topk_ids.shape[1],
                             config,
-                            compute_type=tl.float16,
+                            compute_type=compute_type,
                             use_fp8=use_fp8)
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -465,7 +467,7 @@ def fused_moe(
                             True,
                             1,
                             config,
-                            compute_type=tl.float16,
+                            compute_type=compute_type,
                             use_fp8=use_fp8)
 
     if inplace:
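
For context: the fix replaces a hardcoded compute_type=tl.float16 with a value derived
from the activation dtype, so bfloat16 models no longer have their MoE intermediates
silently computed in float16. The sketch below illustrates that dtype-selection pattern
in isolation; select_compute_type is a hypothetical helper name for illustration only,
not vLLM's actual API, and the asserts are just a usage demonstration.

# Minimal, self-contained sketch of the dtype-selection pattern this patch
# introduces. `select_compute_type` is a hypothetical helper, not vLLM code.
import torch
import triton.language as tl

def select_compute_type(hidden_states: torch.Tensor) -> tl.dtype:
    # Match the Triton compute type to the activation dtype. Hardcoding
    # tl.float16 is wrong for bfloat16 inputs: bfloat16 keeps float32's
    # 8-bit exponent while float16 has only 5 exponent bits, so routing
    # bf16 activations through fp16 can overflow or underflow them.
    return (tl.bfloat16
            if hidden_states.dtype == torch.bfloat16 else tl.float16)

# Usage (illustrative):
x_bf16 = torch.randn(4, 8, dtype=torch.bfloat16)
x_fp16 = torch.randn(4, 8, dtype=torch.float16)
assert select_compute_type(x_bf16) == tl.bfloat16
assert select_compute_type(x_fp16) == tl.float16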