[Minor] Fix type annotation in fused moe (#3045)

This commit is contained in:
Woosuk Kwon 2024-02-26 19:44:29 -08:00 committed by GitHub
parent 2410e320b3
commit 4bd18ec0c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@
import functools import functools
import json import json
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Tuple
import torch import torch
import triton import triton
@ -137,7 +137,7 @@ def fused_moe_kernel(
def moe_align_block_size( def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int, topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block size for matrix multiplication. Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
@ -185,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, config: dict): mul_routed_weight: bool, top_k: int,
config: Dict[str, Any]) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1