[Bugfix][Kernel] Fix compute_type for MoE kernel (#4463)
This commit is contained in:
parent
d627a3d837
commit
fa32207842
@ -433,6 +433,8 @@ def fused_moe(
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
topk_ids, config['BLOCK_SIZE_M'], E)
|
||||
compute_type = (tl.bfloat16
|
||||
if hidden_states.dtype == torch.bfloat16 else tl.float16)
|
||||
|
||||
invoke_fused_moe_kernel(hidden_states,
|
||||
w1,
|
||||
@ -447,7 +449,7 @@ def fused_moe(
|
||||
False,
|
||||
topk_ids.shape[1],
|
||||
config,
|
||||
compute_type=tl.float16,
|
||||
compute_type=compute_type,
|
||||
use_fp8=use_fp8)
|
||||
|
||||
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||
@ -465,7 +467,7 @@ def fused_moe(
|
||||
True,
|
||||
1,
|
||||
config,
|
||||
compute_type=tl.float16,
|
||||
compute_type=compute_type,
|
||||
use_fp8=use_fp8)
|
||||
|
||||
if inplace:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user