#pragma once

#include <torch/all.h>

/// Declaration of the BGMV dispatch entry point. Only the prototype is
/// present in this header; the definition is provided elsewhere.
///
/// NOTE(review): the semantics below are inferred from the parameter names
/// only and must be confirmed against the definition.
///
/// @param y          Tensor argument; presumably the output/accumulator —
///                   TODO confirm.
/// @param x          Tensor argument; presumably the input activations —
///                   TODO confirm.
/// @param w          Tensor argument; presumably the stacked weights —
///                   TODO confirm.
/// @param indicies   Tensor argument (spelling kept as declared); presumably
///                   selects a weight per row of the batch — TODO confirm.
/// @param layer_idx  64-bit integer; presumably a layer selector into `w` —
///                   TODO confirm.
/// @param scale      Double-precision scalar; presumably a multiplier applied
///                   to the result — TODO confirm.
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, double scale);
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
|
|
|
|
torch::Tensor indicies, int64_t layer_idx,
|
2024-06-10 04:23:30 +08:00
|
|
|
double scale, int64_t h_in, int64_t h_out,
|
2024-05-10 00:19:50 +08:00
|
|
|
int64_t y_offset);