30 lines
937 B
Plaintext
30 lines
937 B
Plaintext
#include "core.h"
|
|
|
|
#include <cute/tensor.hpp>
|
|
#include <cutlass/cutlass.h>
|
|
#include <cutlass/array.h>
|
|
#include <cutlass/numeric_types.h>
|
|
|
|
/**
 * Skeleton kernel: resolves this thread's block and thread coordinates and
 * nothing else. It currently performs no loads from `src` and no stores.
 *
 * Launch layout as used by the host wrapper in this file:
 *   grid  = (batch, head, ceil(sequence * head_dim / thread_num))
 *   block = (thread_num)
 * No shared memory is used.
 *
 * @param src        Device pointer to the input data (not read yet).
 * @param stride_a   Stride forwarded by the caller for dimension 0 (unused so far).
 * @param stride_b   Stride forwarded by the caller for dimension 1 (unused so far).
 * @param stride_c   Stride forwarded by the caller for dimension 2 (unused so far).
 * @param thread_num Threads per block chosen by the caller (unused so far).
 *
 * NOTE(review): blockIdx.z is stored as a "sequence" index, but the host sizes
 * grid.z as a count of thread_num-wide chunks over sequence * head_dim -
 * confirm the intended mapping before adding the real computation.
 */
__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
{
    // Block coordinates along the three grid axes.
    const int block_batch = static_cast<int>(blockIdx.x);
    const int block_head = static_cast<int>(blockIdx.y);
    const int block_seq = static_cast<int>(blockIdx.z);

    // Position of this thread inside its (1-D) block.
    const int thread_in_block = static_cast<int>(threadIdx.x);
}
|
|
|
|
/**
 * Host-side launcher for md_mm_kernel.
 *
 * Reads the four sizes of `src` as (batch, head, sequence, head_dim) and
 * launches one block of 256 threads per (batch, head, chunk), where the
 * chunks tile the `sequence * head_dim` elements of each head.
 *
 * @param src 4-D float32 CUDA tensor. Its first three strides (in elements)
 *            are forwarded to the kernel.
 *
 * Raises (via TORCH_CHECK) when `src` is not a 4-D float32 CUDA tensor, when
 * the launch geometry or a stride exceeds what the kernel interface and the
 * hardware grid limits can represent, or when the launch itself fails.
 * Returns without launching when `src` has no elements.
 *
 * NOTE(review): the kernel is launched on the legacy default stream; if the
 * caller runs PyTorch work on a non-default stream, pass
 * at::cuda::getCurrentCUDAStream() as the 4th launch argument instead.
 */
void md_mm(const torch::Tensor &src)
{
    TORCH_CHECK(src.dim() == 4, "md_mm: expected a 4-D tensor, got ", src.dim(), "-D");
    TORCH_CHECK(src.is_cuda(), "md_mm: src must be a CUDA tensor");
    // The raw pointer is handed to the kernel as float*, so any other dtype
    // would be silently reinterpreted.
    TORCH_CHECK(src.scalar_type() == torch::kFloat32, "md_mm: src must be float32");

    const int64_t batch_size = src.size(0);
    const int64_t head_size = src.size(1);
    const int64_t sequence_size = src.size(2);
    const int64_t head_dim = src.size(3);

    // A zero-sized grid dimension is an invalid launch configuration; there is
    // nothing to process anyway.
    if (src.numel() == 0)
    {
        return;
    }

    constexpr int thread_num = 256;
    const int64_t data_block = sequence_size * head_dim;
    const int64_t chunk_count = (data_block + thread_num - 1) / thread_num; // ceil-div

    // Hardware grid limits: x <= 2^31 - 1, y and z <= 65535.
    TORCH_CHECK(batch_size <= 2147483647LL,
                "md_mm: batch size ", batch_size, " exceeds the grid.x limit");
    TORCH_CHECK(head_size <= 65535,
                "md_mm: head count ", head_size, " exceeds the grid.y limit (65535)");
    TORCH_CHECK(chunk_count <= 65535,
                "md_mm: sequence * head_dim needs ", chunk_count,
                " blocks, exceeding the grid.z limit (65535)");

    // The kernel takes strides as int; refuse values that would be truncated.
    const int64_t stride0 = src.stride(0);
    const int64_t stride1 = src.stride(1);
    const int64_t stride2 = src.stride(2);
    TORCH_CHECK(static_cast<int64_t>(static_cast<int>(stride0)) == stride0 &&
                    static_cast<int64_t>(static_cast<int>(stride1)) == stride1 &&
                    static_cast<int64_t>(static_cast<int>(stride2)) == stride2,
                "md_mm: tensor strides do not fit in a 32-bit int");

    dim3 grid(static_cast<unsigned int>(batch_size),
              static_cast<unsigned int>(head_size),
              static_cast<unsigned int>(chunk_count));
    dim3 block(thread_num);

    md_mm_kernel<<<grid, block>>>(static_cast<const float *>(src.data_ptr()),
                                  static_cast<int>(stride0),
                                  static_cast<int>(stride1),
                                  static_cast<int>(stride2),
                                  thread_num);

    // Kernel launches do not return a status; configuration errors only show
    // up through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "md_mm: kernel launch failed: ", cudaGetErrorString(launch_status));
}
|