torch_ext/csrc/attention.cu

#include "core.h"
// Dot product of one query row and one key row (length head_dim).
template <typename scalar_t>
__device__ scalar_t vecsum(const scalar_t *q, const scalar_t *k, int head_dim)
{
    scalar_t sum = 0;
    for (int i = 0; i < head_dim; ++i)
        sum += q[i] * k[i];
    return sum;
}

// Single-query (decode-step) attention: q is [batch, head_num, 1, head_dim],
// k and v are [batch, head_num, seq_len, head_dim]. Launched with one block
// per batch element and one thread per head; batch_size and hidden_dim are
// unused here because the grid already covers the batch.
template <typename scalar_t>
__global__ void attention_kernel(const scalar_t *q,
                                 const scalar_t *k,
                                 const scalar_t *v,
                                 int head_num,
                                 int head_dim,
                                 int seq_len,
                                 int batch_size,
                                 int hidden_dim,
                                 scalar_t *output)
{
    int tid = threadIdx.x; // head index within this batch element
    if (tid >= head_num) return;
    // calculate the per-head offsets.
    const scalar_t *q_row = q + blockIdx.x * head_num * 1 * head_dim + tid * head_dim;
    const scalar_t *k_mat = k + (blockIdx.x * head_num + tid) * seq_len * head_dim;
    const scalar_t *v_mat = v + (blockIdx.x * head_num + tid) * seq_len * head_dim;
    scalar_t *out_row = output + blockIdx.x * head_num * head_dim + tid * head_dim;
    scalar_t scale = (scalar_t)rsqrtf((float)head_dim);
    // calculate the running max of the scaled scores (numerically stable softmax).
    scalar_t max_score = vecsum(q_row, k_mat, head_dim) * scale;
    for (int j = 1; j < seq_len; ++j) {
        scalar_t s = vecsum(q_row, k_mat + j * head_dim, head_dim) * scale;
        if (s > max_score) max_score = s;
    }
    // calculate the softmax denominator.
    scalar_t denom = 0;
    for (int j = 0; j < seq_len; ++j)
        denom += (scalar_t)expf((float)(vecsum(q_row, k_mat + j * head_dim, head_dim) * scale - max_score));
    // calculate the weighted sum of the value rows.
    for (int i = 0; i < head_dim; ++i)
        out_row[i] = 0;
    for (int j = 0; j < seq_len; ++j) {
        scalar_t w = (scalar_t)expf((float)(vecsum(q_row, k_mat + j * head_dim, head_dim) * scale - max_score)) / denom;
        for (int i = 0; i < head_dim; ++i)
            out_row[i] += w * v_mat[j * head_dim + i];
    }
}
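
// A minimal host-side launch sketch, not part of the original file: the
// grid/block shape below (one block per batch element, one thread per head)
// is an assumption matching the offset arithmetic in the kernel, and the
// pointers d_q, d_k, d_v, d_out are hypothetical device buffers already laid
// out as described above.
template <typename scalar_t>
void launch_attention(const scalar_t *d_q, const scalar_t *d_k,
                      const scalar_t *d_v, int head_num, int head_dim,
                      int seq_len, int batch_size, int hidden_dim,
                      scalar_t *d_out)
{
    dim3 grid(batch_size); // one block per batch element
    dim3 block(head_num);  // one thread per head
    attention_kernel<scalar_t><<<grid, block>>>(
        d_q, d_k, d_v, head_num, head_dim, seq_len,
        batch_size, hidden_dim, d_out);
}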