torch_ext/csrc/md.cu
2024-11-18 19:54:12 +08:00

30 lines
937 B
Plaintext

#include "core.h"
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
/**
 * md_mm_kernel -- placeholder kernel.
 *
 * At present the body only decodes the launch coordinates of the calling
 * thread into local variables; it performs no loads from `src` and no stores,
 * so launching it has no observable effect.
 *
 * Launch layout as used by md_mm() in this file:
 *   grid  = (size(0), size(1), ceil(size(2) * size(3) / thread_num))
 *   block = (thread_num)
 * No shared memory is used.
 *
 * @param src        Pointer handed over by the host launcher (the tensor's
 *                   data pointer). Unused so far.
 * @param stride_a   md_mm() passes src.stride(0). Unused so far.
 * @param stride_b   md_mm() passes src.stride(1). Unused so far.
 * @param stride_c   md_mm() passes src.stride(2). Unused so far.
 * @param thread_num md_mm() passes the block size it launched with. Unused so far.
 */
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
{
    // Thread position inside its (one-dimensional) block.
    const int local_tid = threadIdx.x;

    // Block coordinates, one per grid axis.
    // NOTE(review): the z coordinate was originally named "sequence_idx", but
    // md_mm() sizes grid.z as a count of thread_num-sized chunks over
    // size(2) * size(3), so it looks like a chunk index rather than a
    // sequence position -- confirm before building indexing on top of it.
    const int z_block = blockIdx.z;
    const int head = blockIdx.y;
    const int batch = blockIdx.x;
}
/**
 * Host-side launcher for md_mm_kernel.
 *
 * Reads src.size(0..3) as (batch_size, head_size, sequence_size, head_dim) and
 * launches one 256-thread block per (batch, head, chunk), where the number of
 * chunks is ceil(sequence_size * head_dim / 256). The tensor's first three
 * strides are forwarded to the kernel, so `src` does not have to be contiguous.
 *
 * @param src 4-D float32 CUDA tensor.
 *
 * Raises (via TORCH_CHECK) when `src` is not a CUDA tensor, is not 4-D, is not
 * float32, when a size/stride does not fit the kernel's 32-bit int parameters,
 * or when the kernel launch is rejected by the CUDA runtime.
 * An empty tensor is a no-op.
 *
 * NOTE(review): the launch goes to the default stream, exactly as before.
 * PyTorch extensions normally launch on the current PyTorch CUDA stream; that
 * was not changed here because it needs a header this file does not visibly
 * include -- confirm and switch if appropriate.
 */
void md_mm(const torch::Tensor &src)
{
    TORCH_CHECK(src.is_cuda(), "md_mm: src must be a CUDA tensor");
    TORCH_CHECK(src.dim() == 4, "md_mm: src must be 4-dimensional, got ", src.dim(), " dimension(s)");
    TORCH_CHECK(src.scalar_type() == at::ScalarType::Float,
                "md_mm: src must be float32, got ", src.scalar_type());

    // A zero-sized dimension would produce a zero-sized grid, which is an
    // invalid launch configuration; there is nothing to process anyway.
    if (src.numel() == 0)
    {
        return;
    }

    // The kernel interface is 32-bit; refuse values that would be truncated
    // instead of silently passing a wrapped-around number to the device.
    auto to_int = [](int64_t value, const char *what) -> int {
        const int narrowed = static_cast<int>(value);
        TORCH_CHECK(static_cast<int64_t>(narrowed) == value,
                    "md_mm: ", what, " (", value, ") does not fit in a 32-bit int");
        return narrowed;
    };

    const int batch_size = to_int(src.size(0), "size(0)");
    const int head_size = to_int(src.size(1), "size(1)");
    const int64_t sequence_size = src.size(2);
    const int64_t head_dim = src.size(3);

    // Computed in 64 bits so the product cannot overflow before the ceil-div.
    const int64_t data_block = sequence_size * head_dim;
    const int thread_num = 256;
    const int chunk_count = to_int((data_block + thread_num - 1) / thread_num, "grid.z");

    const dim3 grid(batch_size, head_size, chunk_count);
    const dim3 block(thread_num);

    // data_ptr<float>() is the typed accessor (it verifies the dtype itself),
    // replacing the unchecked reinterpret_cast of the untyped pointer.
    md_mm_kernel<<<grid, block>>>(src.data_ptr<float>(),
                                  to_int(src.stride(0), "stride(0)"),
                                  to_int(src.stride(1), "stride(1)"),
                                  to_int(src.stride(2), "stride(2)"),
                                  thread_num);

    // Kernel launches do not return a status; configuration errors (for
    // example grid.y / grid.z above the device limit) only show up here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "md_mm: md_mm_kernel launch failed: ", cudaGetErrorString(launch_status));
}