[Bugfix][Kernel] Fix compute_type for MoE kernel (#4463)
parent d627a3d837
commit fa32207842
@@ -433,6 +433,8 @@ def fused_moe(
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
+    compute_type = (tl.bfloat16
+                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
 
     invoke_fused_moe_kernel(hidden_states,
                             w1,
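
The hunk above replaces a hard-coded tl.float16 compute type with one derived from the activation dtype, so bfloat16 models are no longer forced through float16 math inside the kernel. A minimal sketch of the same selection as a standalone helper (the helper name is an illustrative assumption, not part of this patch):

    import torch
    import triton.language as tl

    def select_compute_type(dtype: torch.dtype):
        # Mirror the patch: bfloat16 activations compute in tl.bfloat16,
        # everything else falls back to tl.float16.
        if dtype == torch.bfloat16:
            return tl.bfloat16
        return tl.float16

The two invoke_fused_moe_kernel call sites below are then updated to pass this value instead of the literal tl.float16.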
@@ -447,7 +449,7 @@ def fused_moe(
                             False,
                             topk_ids.shape[1],
                             config,
-                            compute_type=tl.float16,
+                            compute_type=compute_type,
                             use_fp8=use_fp8)
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -465,7 +467,7 @@ def fused_moe(
                             True,
                             1,
                             config,
-                            compute_type=tl.float16,
+                            compute_type=compute_type,
                             use_fp8=use_fp8)
 
     if inplace:
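
One concrete way to see why hard-coding tl.float16 is wrong for bfloat16 inputs: the two formats trade mantissa precision for exponent range differently. float16 saturates near 65504, while bfloat16 spans roughly the same range as float32, so routing bfloat16 activations through float16 math can silently overflow to inf. A quick plain-PyTorch illustration, independent of the kernel itself:

    import torch

    # 3.0e38 is representable in bfloat16 (max ~3.39e38) ...
    x = torch.tensor([3.0e38], dtype=torch.bfloat16)
    # ... but exceeds float16's ~65504 maximum on conversion.
    print(x.to(torch.float16))  # tensor([inf], dtype=torch.float16)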