#include "core.h"

// Calculate the vector sum of products (dot product) of a matrix row
// and a matrix column. The length parameter is an addition (the original
// stub took none) so the function compiles; the kernel passes head_dim.
template <typename scalar_t>
__device__ scalar_t vecsum(const scalar_t *q, const scalar_t *k, int dim)
{
    scalar_t sum = 0;
    for (int i = 0; i < dim; ++i)
        sum += q[i] * k[i];
    return sum;
}

template <typename scalar_t>
__global__ void attention_kernel(const scalar_t *q,
                                 const scalar_t *k,
                                 const scalar_t *v,
                                 int head_num,
                                 int head_dim,
                                 int seq_len,
                                 int batch_size,
                                 int hidden_dim,
                                 scalar_t *output)
{
    // Assumed mapping: one block per batch element, one thread per head
    // (launch as <<<batch_size, head_num>>>).
    int tid = threadIdx.x;
    if (tid >= head_num)
        return;

    // Calculate the offsets for this batch element and head; q holds a
    // single query token per batch element (hence the factor of 1).
    int q_offset = blockIdx.x * head_num * 1 * head_dim + tid * head_dim;
    int k_offset = blockIdx.x * head_num * seq_len * head_dim + tid * seq_len * head_dim;
    int v_offset = blockIdx.x * head_num * seq_len * head_dim + tid * seq_len * head_dim;

    // Calculate the scores (the gemm step): the first pass finds the row
    // max for a numerically stable softmax. The usual 1/sqrt(head_dim)
    // scaling is omitted here, as in the original stub.
    scalar_t max_score = vecsum(q + q_offset, k + k_offset, head_dim);
    for (int j = 1; j < seq_len; ++j) {
        scalar_t s = vecsum(q + q_offset, k + k_offset + j * head_dim, head_dim);
        if (s > max_score)
            max_score = s;
    }

    // Calculate the softmax and the weighted sum over v in a single pass,
    // accumulating unnormalized values directly into output.
    int out_offset = blockIdx.x * head_num * head_dim + tid * head_dim;
    scalar_t denom = 0;
    for (int d = 0; d < head_dim; ++d)
        output[out_offset + d] = 0;
    for (int j = 0; j < seq_len; ++j) {
        scalar_t s = vecsum(q + q_offset, k + k_offset + j * head_dim, head_dim);
        scalar_t w = exp(s - max_score);
        denom += w;
        for (int d = 0; d < head_dim; ++d)
            output[out_offset + d] += w * v[v_offset + j * head_dim + d];
    }
    for (int d = 0; d < head_dim; ++d)
        output[out_offset + d] /= denom;
}
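
// Usage sketch, not part of the original file: a minimal, hypothetical
// host-side launcher (the name launch_attention is illustrative only),
// assuming float tensors already resident on the device and
// head_num <= 1024 (the per-block thread limit).
void launch_attention(const float *q, const float *k, const float *v,
                      int head_num, int head_dim, int seq_len,
                      int batch_size, int hidden_dim, float *output)
{
    // One block per batch element, one thread per head.
    attention_kernel<float><<<batch_size, head_num>>>(
        q, k, v, head_num, head_dim, seq_len, batch_size, hidden_dim, output);
}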