[Minor] Fix type annotation in fused moe (#3045)
This commit is contained in:
parent
2410e320b3
commit
4bd18ec0c7
@ -2,7 +2,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@ -137,7 +137,7 @@ def fused_moe_kernel(
|
|||||||
|
|
||||||
def moe_align_block_size(
|
def moe_align_block_size(
|
||||||
topk_ids: torch.Tensor, block_size: int,
|
topk_ids: torch.Tensor, block_size: int,
|
||||||
num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
|
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
|
Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
|
||||||
|
|
||||||
@ -185,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
|||||||
sorted_token_ids: torch.Tensor,
|
sorted_token_ids: torch.Tensor,
|
||||||
expert_ids: torch.Tensor,
|
expert_ids: torch.Tensor,
|
||||||
num_tokens_post_padded: torch.Tensor,
|
num_tokens_post_padded: torch.Tensor,
|
||||||
mul_routed_weight: bool, top_k: int, config: dict):
|
mul_routed_weight: bool, top_k: int,
|
||||||
|
config: Dict[str, Any]) -> None:
|
||||||
assert topk_weights.stride(1) == 1
|
assert topk_weights.stride(1) == 1
|
||||||
assert sorted_token_ids.stride(0) == 1
|
assert sorted_token_ids.stride(0) == 1
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user