[Bugfix][Kernel] Fix compute_type for MoE kernel (#4463)

Woosuk Kwon 2024-04-29 22:05:40 -07:00 committed by GitHub
parent d627a3d837
commit fa32207842


@@ -433,6 +433,8 @@ def fused_moe(
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
+    compute_type = (tl.bfloat16
+                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
     invoke_fused_moe_kernel(hidden_states,
                             w1,
@@ -447,7 +449,7 @@ def fused_moe(
                             False,
                             topk_ids.shape[1],
                             config,
-                            compute_type=tl.float16,
+                            compute_type=compute_type,
                             use_fp8=use_fp8)
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -465,7 +467,7 @@ def fused_moe(
                             True,
                             1,
                             config,
-                            compute_type=tl.float16,
+                            compute_type=compute_type,
                             use_fp8=use_fp8)
     if inplace:
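
Why the change matters: fused_moe previously hardcoded compute_type=tl.float16 in both invoke_fused_moe_kernel calls, so bfloat16 activations were silently computed and stored as float16. The fix derives the Triton compute type from hidden_states.dtype instead. Below is a minimal sketch of that selection logic, assuming only the two dtypes the diff handles; select_compute_type is a hypothetical helper for illustration, not part of vLLM (the actual commit inlines the conditional).

import torch
import triton.language as tl

def select_compute_type(hidden_states: torch.Tensor):
    # Mirror the conditional added by this commit: keep bfloat16
    # activations in bfloat16, fall back to float16 otherwise
    # (the previous behavior used tl.float16 for all inputs).
    if hidden_states.dtype == torch.bfloat16:
        return tl.bfloat16
    return tl.float16

With the selected value passed as compute_type=compute_type to both kernel invocations, the intermediate caches keep the activation dtype end to end instead of being down-converted.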