2024-02-06 09:38:02 +08:00
|
|
|
#pragma once
|
|
|
|
|
|
2024-06-10 04:23:30 +08:00
|
|
|
#include <torch/all.h>
|
2024-02-06 09:38:02 +08:00
|
|
|
|
2024-05-22 15:18:41 +08:00
|
|
|
/// Declaration of `topk_softmax`; the definition is not visible in this file.
///
/// Every argument is a `torch::Tensor` passed by non-const lvalue reference
/// and the function returns `void`, so any results must be delivered through
/// the arguments themselves.
///
/// NOTE(review): the parameter names suggest that `gating_output` is the
/// input and that `topk_weights`, `topk_indices` and `token_expert_indices`
/// are written by the call -- confirm against the implementation, along with
/// the expected shapes, dtypes and device placement, none of which can be
/// determined from this declaration.
///
/// @param topk_weights          tensor argument (presumably an output -- TODO confirm)
/// @param topk_indices          tensor argument (presumably an output -- TODO confirm)
/// @param token_expert_indices  tensor argument (presumably an output -- TODO confirm)
/// @param gating_output         tensor argument (presumably the input -- TODO confirm)
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output);
|