#include "core.h" // calculate the vec cum of different matrix row and col. template __device__ scalar_t vecsum(scalar_t *q, scalar_t *k) { } template __global__ void attention_kernel(const scalar_t *q, const scalar_t *k, const scalar_t *v, int head_num, int head_dim, int seq_len, int batch_size, int hidden_dim, scalar_t *output) { // calculate the gemm. int tid = threadIdx.x; // caculate the offset. int q_offset = blockIdx.x * head_num * 1 * head_dim; int k_offset = blockIdx.x * head_num * seq_len * head_dim; int v_offset = blockIdx.x * head_num * seq_len * head_dim; // calculate the sum. // calculate the softmax // calculate the weighted sum. }