#include "core.h"

#include <cstdint>
#include <limits>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

// Kernel launched by md_mm over a 4-D float32 tensor, one block column per
// (batch, head) pair.
//
// Launch layout (as configured by md_mm):
//   grid  = (batch_size, head_size, ceil(sequence_size * head_dim / thread_num))
//   block = (thread_num)
//   no dynamic shared memory.
//
// Parameters:
//   src        device pointer to the float32 data of the tensor.
//   stride_a   element stride of dim 0 (batch).
//   stride_b   element stride of dim 1 (head).
//   stride_c   element stride of dim 2 (sequence).
//   thread_num threads per block; equals blockDim.x for launches from md_mm.
//
// NOTE(review): the body is currently a stub. It only decodes its
// coordinates and performs no reads or writes. A real body will need:
//   * a bounds check. blockIdx.z * thread_num + threadIdx.x can exceed
//     sequence_size * head_dim in the last z-block, but that total is not
//     passed to the kernel, so it cannot currently reject tail threads.
//   * no stride for dim 3 is passed, so the body presumably assumes the last
//     dimension is contiguous. TODO confirm.
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b,
                             int stride_c, int thread_num) {
  int batch_idx = blockIdx.x;
  int head_idx = blockIdx.y;
  // blockIdx.z indexes thread_num-sized chunks of the flattened
  // (sequence, head_dim) slab. It is not a sequence position.
  int chunk_idx = blockIdx.z;
  int tidx = threadIdx.x;
}

// Launches md_mm_kernel over `src`, which must be a 4-D float32 CUDA tensor.
//
// The grid has one (x, y) block per (batch, head) pair. Enough z-blocks of
// `thread_num` threads are added to cover size(2) * size(3) elements of each
// slab. The launch runs on src's device, using PyTorch's current CUDA stream
// for that device, and is asynchronous. Launch-configuration errors are
// reported immediately by C10_CUDA_KERNEL_LAUNCH_CHECK.
//
// Throws c10::Error (via TORCH_CHECK) when:
//   * src is not a CUDA tensor, not 4-D, or not float32;
//   * a size or stride does not fit the kernel's `int` parameters;
//   * the grid would exceed CUDA's per-dimension limits.
// A tensor with zero elements is a no-op.
void md_mm(const torch::Tensor &src) {
  TORCH_CHECK(src.is_cuda(), "md_mm: src must be a CUDA tensor");
  TORCH_CHECK(src.dim() == 4, "md_mm: src must be 4-D, got ", src.dim(), "-D");
  TORCH_CHECK(src.scalar_type() == torch::kFloat32,
              "md_mm: src must be float32, got ", src.scalar_type());

  // A zero-sized dimension would yield a zero grid dimension, which CUDA
  // rejects as an invalid launch configuration. There is no work to do.
  if (src.numel() == 0) {
    return;
  }

  constexpr int thread_num = 256;
  constexpr int64_t kMaxInt = std::numeric_limits<int>::max();
  constexpr int64_t kMaxGridYZ = 65535;  // CUDA limit for gridDim.y / gridDim.z

  const int64_t batch_size = src.size(0);
  const int64_t head_size = src.size(1);
  const int64_t data_block = src.size(2) * src.size(3);
  const int64_t z_blocks = (data_block + thread_num - 1) / thread_num;

  TORCH_CHECK(batch_size <= kMaxInt, "md_mm: batch size ", batch_size,
              " exceeds gridDim.x limit");
  TORCH_CHECK(head_size <= kMaxGridYZ, "md_mm: head size ", head_size,
              " exceeds gridDim.y limit of ", kMaxGridYZ);
  TORCH_CHECK(z_blocks <= kMaxGridYZ, "md_mm: size(2) * size(3) = ", data_block,
              " needs ", z_blocks, " z-blocks, exceeding gridDim.z limit of ",
              kMaxGridYZ);
  // The kernel takes strides as int, so reject values that would be truncated.
  TORCH_CHECK(src.stride(0) <= kMaxInt && src.stride(1) <= kMaxInt &&
                  src.stride(2) <= kMaxInt,
              "md_mm: src strides exceed the int range expected by md_mm_kernel");

  // Make src's device current so that the stream lookup and the launch target
  // the device that owns the data.
  const c10::cuda::CUDAGuard device_guard(src.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const dim3 grid(static_cast<unsigned int>(batch_size),
                  static_cast<unsigned int>(head_size),
                  static_cast<unsigned int>(z_blocks));
  const dim3 block(thread_num);

  md_mm_kernel<<<grid, block, 0, stream>>>(
      src.data_ptr<float>(), static_cast<int>(src.stride(0)),
      static_cast<int>(src.stride(1)), static_cast<int>(src.stride(2)),
      thread_num);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}