#include "core.h"
#include <iostream>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
// #include <mma.h>
#include <cuda_runtime.h>
|
|
// Element-wise addition: one thread per output element.
__global__ void add_two_tensors_cuda(const float *input0, const float *input1, float *output, int size)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < size)
    {
        output[index] = input0[index] + input1[index];
    }
    // No trailing __syncthreads() is needed: threads never share data here.
}
|
|
// Logistic function; expf keeps the computation in fp32.
__device__ float sigmoid(float x)
{
    return 1.0f / (1.0f + expf(-x));
}
|
|
// Naive GEMM: one thread per element of the (row x col2) output.
// in1 is (row x col) and in2 is (col x col2), both row-major.
template <typename T>
__global__ void matmul_cuda(const T *in1, const T *in2, T *output, int row, int col, int col2)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int row_index = index / col2;
    int col_index = index % col2;
    if (index < row * col2)
    {
        float sum = 0;
        for (int i = 0; i < col; i++)
        {
            sum += in1[row_index * col + i] * in2[i * col2 + col_index];
        }
        output[index] = sum;
    }
}
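
// Example of the index math above: with col2 = 8, index 19 maps to
// row_index = 19 / 8 = 2 and col_index = 19 % 8 = 3, i.e. output element
// (2, 3) in row-major order.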
|
|
// Fused GEMM + sigmoid: same traversal as matmul_cuda, with the activation
// applied before the single write to global memory.
template <typename T>
__global__ void matmul_sigmoid_cuda(const T *in1, const T *in2, T *output, int row, int col, int col2)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int row_index = index / col2;
    int col_index = index % col2;
    if (index < row * col2)
    {
        float sum = 0;
        for (int i = 0; i < col; i++)
        {
            sum += in1[row_index * col + i] * in2[i * col2 + col_index];
        }
        output[index] = sigmoid(sum);
    }
}
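
// Fusing the activation into the GEMM avoids a second kernel launch and an
// extra round trip of the (row x col2) result through global memory.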
|
|
#define BASE_BLOCK 256
// Round the grid size up so every element is covered by a thread. Note that
// the macro deliberately captures input1, src, src1, and dest from the
// caller's scope.
#define CALL_ADD_FUNCTION \
    add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));

void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
{
    float *src = input1.data_ptr<float>();
    float *src1 = input2.data_ptr<float>();
    float *dest = output.data_ptr<float>();
    CALL_ADD_FUNCTION;
}
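
// Example of the round-up: with 1000 elements, (1000 + 255) / 256 = 4 blocks
// of 256 threads, so indices 1000..1023 are masked off by the bounds check
// in the kernel.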
|
|
// Write rotary position embeddings to the columns of the output tensor.
void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start)
{
    // TODO: not implemented yet.
    printf("rope_tensors is not implemented\n");
}
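
// A minimal sketch (kept as a comment) of what the missing kernel could look
// like, rotating each even/odd pair of the last dimension. The (seq_len x
// head_dim) layout, the 10000.0f frequency base, and reading rope_index_start
// as the absolute position offset are assumptions, not taken from this file:
//
// __global__ void rope_cuda(const float *input, float *output,
//                           int seq_len, int head_dim, int rope_index_start)
// {
//     int index = blockIdx.x * blockDim.x + threadIdx.x; // one thread per pair
//     if (index < seq_len * head_dim / 2)
//     {
//         int pos = index / (head_dim / 2) + rope_index_start; // token position
//         int pair = index % (head_dim / 2);                   // pair id in the head
//         float theta = pos * powf(10000.0f, -2.0f * pair / head_dim);
//         float c = cosf(theta), s = sinf(theta);
//         int base = (index / (head_dim / 2)) * head_dim + 2 * pair;
//         output[base]     = input[base] * c - input[base + 1] * s;
//         output[base + 1] = input[base] * s + input[base + 1] * c;
//     }
// }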
|
|
// Host wrapper: (row x col) @ (col x col2) -> (row x col2).
void matmul(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
{
    float *in1 = input1.data_ptr<float>();
    float *in2 = input2.data_ptr<float>();
    float *out = output.data_ptr<float>();
    int row = input1.size(0);
    int col = input1.size(1);
    int col2 = input2.size(1);
    matmul_cuda<float><<<(row * col2 + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(in1, in2, out, row, col, col2);
}
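
// A minimal sketch of how matmul might be sanity-checked from C++ against
// torch::mm; the shapes are arbitrary, and the call assumes contiguous fp32
// CUDA tensors:
//
// torch::Tensor a = torch::rand({128, 64}, torch::kCUDA);
// torch::Tensor b = torch::rand({64, 256}, torch::kCUDA);
// torch::Tensor c = torch::empty({128, 256}, torch::kCUDA);
// matmul(a, b, c);
// cudaDeviceSynchronize();
// TORCH_CHECK(torch::allclose(c, torch::mm(a, b), 1e-4, 1e-5));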
|
|
// Host wrapper for the fused GEMM + sigmoid kernel.
void matmul_sigmoid(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
{
    float *in1 = input1.data_ptr<float>();
    float *in2 = input2.data_ptr<float>();
    float *out = output.data_ptr<float>();
    int row = input1.size(0);
    int col = input1.size(1);
    int col2 = input2.size(1);
    matmul_sigmoid_cuda<float><<<(row * col2 + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(in1, in2, out, row, col, col2);
}
|
|
// An unfinished flash-attention experiment, kept for reference. It only
// computes one q.k dot product per (batch, head, seq) index; there is no
// softmax and v is never read.
// template <typename dtype>
// __global__ void flash_attention(const dtype *q,
//                                 const dtype *k, const dtype *v, dtype *output,
//                                 int batch_size, int seq_len, int head_num, int head_dim)
// {
//     int index = blockIdx.x * blockDim.x + threadIdx.x;
//     int batch_index = index / (head_num * seq_len);
//     int head_index = (index % (head_num * seq_len)) / seq_len;
//     int seq_index = index % seq_len;
//     if (index < batch_size * head_num * seq_len)
//     {
//         float sum = 0;
//         for (int i = 0; i < head_dim; i++)
//         {
//             sum += q[batch_index * head_num * seq_len * head_dim + head_index * seq_len * head_dim + i] * k[batch_index * head_num * seq_len * head_dim + head_index * seq_len * head_dim + i];
//         }
//         output[batch_index * head_num * seq_len + head_index * seq_len + seq_index] = sum;
//     }
// }
|
|
#define TILE_WIDTH 32
// Tiled square matmul using shared memory: C = A * B, all row-major.
// Assumes width is a multiple of TILE_WIDTH.
__global__ void matrixMul(float *C, float *A, float *B, int width)
{
    __shared__ float As[TILE_WIDTH][TILE_WIDTH];
    __shared__ float Bs[TILE_WIDTH][TILE_WIDTH];

    int bx = blockIdx.x;
    int by = blockIdx.y;
    int tx = threadIdx.x;
    int ty = threadIdx.y;

    int Row = by * TILE_WIDTH + ty;
    int Col = bx * TILE_WIDTH + tx;

    float Pvalue = 0;

    for (int m = 0; m < width / TILE_WIDTH; ++m)
    {
        // Stage one tile of A and one tile of B in shared memory.
        As[ty][tx] = A[Row * width + (m * TILE_WIDTH + tx)];
        Bs[ty][tx] = B[(m * TILE_WIDTH + ty) * width + Col];
        __syncthreads();

        for (int k = 0; k < TILE_WIDTH; ++k)
        {
            Pvalue += As[ty][k] * Bs[k][tx];
        }
        // Wait before the tiles are overwritten in the next iteration.
        __syncthreads();
    }

    C[Row * width + Col] = Pvalue;
}
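
// Each block loads every element of A and B it needs into shared memory once
// per tile, so global-memory traffic drops by roughly a factor of TILE_WIDTH
// compared with the naive matmul_cuda kernel above.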
|
|
// Host wrapper for the tiled kernel; assumes square inputs whose width is a
// multiple of TILE_WIDTH.
void matmul_shared(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
{
    float *in1 = input1.data_ptr<float>();
    float *in2 = input2.data_ptr<float>();
    float *out = output.data_ptr<float>();
    int width = input1.size(0);
    dim3 dimBlock(TILE_WIDTH, TILE_WIDTH);
    dim3 dimGrid(width / TILE_WIDTH, width / TILE_WIDTH);
    // The grid comes first in the launch configuration, and the kernel takes
    // the output matrix as its first argument.
    matrixMul<<<dimGrid, dimBlock>>>(out, in1, in2, width);
}
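
// If width were not a multiple of TILE_WIDTH, the loop bound would become
// (width + TILE_WIDTH - 1) / TILE_WIDTH and the tile loads and final store
// would need bounds guards; a sketch of the guarded variant (an assumption,
// not part of the kernel above):
//
// As[ty][tx] = (Row < width && m * TILE_WIDTH + tx < width)
//                  ? A[Row * width + m * TILE_WIDTH + tx] : 0.0f;
// Bs[ty][tx] = (Col < width && m * TILE_WIDTH + ty < width)
//                  ? B[(m * TILE_WIDTH + ty) * width + Col] : 0.0f;
// ...
// if (Row < width && Col < width)
//     C[Row * width + Col] = Pvalue;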
|
|
// // Matrix tile dimensions for WMMA.
// #define WMMA_M 16
// #define WMMA_N 16
// #define WMMA_K 16

// __global__ void matrixMul_mma(half *A, half *B, float *C, int M, int N, int K)
// {
//     // Stage the input matrices in shared memory.
//     __shared__ half a[WMMA_M][WMMA_K];
//     __shared__ half b[WMMA_K][WMMA_N];

//     int warpId = threadIdx.x / warpSize;
//     int laneId = threadIdx.x % warpSize;

//     // Declare the WMMA matrix fragments.
//     wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
//     wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
//     wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

//     wmma::fill_fragment(c_frag, 0.0f);

//     // Load the tiles into the fragments.
//     wmma::load_matrix_sync(a_frag, a[warpId], WMMA_K);
//     wmma::load_matrix_sync(b_frag, b[warpId], WMMA_N);

//     // Run the tensor-core multiply-accumulate.
//     wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

//     // Store the result back to global memory.
//     wmma::store_matrix_sync(&C[warpId * WMMA_M * WMMA_N], c_frag, WMMA_N, wmma::mem_row_major);
// }

// at::Tensor gemm_mma(const at::Tensor &a, const at::Tensor &b)
// {
//     __half *a_p = reinterpret_cast<__half *>(a.data_ptr<at::Half>());
//     __half *b_p = reinterpret_cast<__half *>(b.data_ptr<at::Half>());
//     at::Tensor c = torch::empty({a.size(0), b.size(1)}, torch::kFloat32);
//     float *c_p = c.data_ptr<float>();
//     matrixMul_mma<<<1, 1>>>(a_p, b_p, c_p, a.size(0), b.size(1), 16);
//     return c;
// }
|
|
/*
// Matrix dimensions.
#define MM 1024
#define NN 1024
#define KK 1024

// Dimensions of the WMMA matrix fragments.
constexpr int WMMA_M = 16;
constexpr int WMMA_N = 16;
constexpr int WMMA_K = 16;

// Matrix multiplication kernel.
__global__ void wmma_gemm(__half *a, __half *b, __half *c)
{
    // Locate the tile this thread's warp should process.
    int wmma_m_index = (blockIdx.x * blockDim.x + threadIdx.x) / WMMA_M;
    int wmma_n_index = (blockIdx.y * blockDim.y + threadIdx.y) / WMMA_N;

    // Declare the WMMA matrix fragments.
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, __half, nvcuda::wmma::row_major> a_frag;
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, __half, nvcuda::wmma::row_major> b_frag;
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, __half> c_frag;

    // Initialize the C fragment.
    nvcuda::wmma::fill_fragment(c_frag, __float2half(0.0f));

    // Walk over the K dimension one fragment at a time.
    for (int k = 0; k < KK; k += WMMA_K)
    {
        // Load the A and B fragments.
        nvcuda::wmma::load_matrix_sync(a_frag, a + wmma_m_index * WMMA_M + k * MM, MM);
        nvcuda::wmma::load_matrix_sync(b_frag, b + k * KK + wmma_n_index * WMMA_N, KK);

        // Multiply and accumulate.
        nvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }

    // Store the result back to the C matrix.
    nvcuda::wmma::store_matrix_sync(c + wmma_m_index * WMMA_M + wmma_n_index * WMMA_N,
                                    c_frag,
                                    MM,
                                    nvcuda::wmma::mem_row_major);
}

at::Tensor gemm_mma(const at::Tensor &A, const at::Tensor &B)
{
    // Get the input tensors' data pointers, reinterpreted as half.
    __half *A_data = reinterpret_cast<__half *>(A.data_ptr<at::Half>());
    __half *B_data = reinterpret_cast<__half *>(B.data_ptr<at::Half>());

    // Create the output tensor.
    torch::Tensor C = torch::zeros({A.size(0), B.size(1)}, A.options());
    dim3 grid(MM / WMMA_M, NN / WMMA_N);
    dim3 block(WMMA_M, WMMA_N);
    // Get the output tensor's data pointer.
    __half *C_data = reinterpret_cast<__half *>(C.data_ptr<at::Half>());

    // Launch the CUDA kernel.
    wmma_gemm<<<grid, block>>>(A_data, B_data, C_data);

    return C;
}
*/
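
// A minimal sketch of how the live wrappers might be exposed to Python with
// pybind11; it assumes core.h pulls in torch/extension.h and that no
// PYBIND11_MODULE block exists elsewhere in the project:
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// {
//     m.def("add_two_tensors", &add_two_tensors, "element-wise add (CUDA)");
//     m.def("matmul", &matmul, "naive matmul (CUDA)");
//     m.def("matmul_sigmoid", &matmul_sigmoid, "fused matmul+sigmoid (CUDA)");
//     m.def("matmul_shared", &matmul_shared, "tiled shared-memory matmul (CUDA)");
// }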