// torch_ext/csrc/matrix.cu
#include "core.h"
#include <iostream>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda_runtime.h>
#define BLOCK_SIZE 32
// Naive matmul: one thread computes one element of C = A (M x K) * B (K x N).
__global__ void org_mm_cuda(const float *a, const float *b, float *c, int M, int N, int K)
{
    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    // printf("row is %d, col is %d\n", row, col);
    int src_len = M * K;
    int dest_len = K * N;
    if (row < M && col < N)
    {
        float sum = 0;
        // Dot product of row `row` of A with column `col` of B.
        for (int i = 0; i < K; i++)
        {
            int src_index = row * K + i;
            int dest_index = i * N + col;
            if (src_index < src_len && dest_index < dest_len)
                sum += a[src_index] * b[dest_index];
        }
        c[row * N + col] = sum;
    }
}
// Unfinished placeholder: the tile loads were meant to live here, but the body was never
// filled in, so the tiled kernels below load their shared-memory tiles inline instead.
__device__ void load2shared(const float *a, const float *b, int block_size)
{
    int col = threadIdx.x;
    int row = threadIdx.y;
    bool col_first = true;
    bool row_first = true;
    // load one column or one row of the tile, depending on the index
    for (int i = 0; i < block_size; i++)
    {
        if (col_first)
        {
            // TODO: load a column of block_size elements
        }
        else
        {
            // TODO: load a row of block_size elements
        }
    }
}
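// The placeholder above never fills the tiles. A minimal sketch of one way it could be
// completed is shown here (load2shared_tile is a hypothetical name, and the shared tiles
// are assumed to be passed in by the caller); the kernels in this file do not call it and
// keep loading their tiles inline.
__device__ void load2shared_tile(const float *a, const float *b,
                                 float a_tile[BLOCK_SIZE][BLOCK_SIZE],
                                 float b_tile[BLOCK_SIZE][BLOCK_SIZE],
                                 int M, int N, int K, int tile_idx)
{
    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    int a_col = tile_idx * BLOCK_SIZE + threadIdx.x;
    int b_row = tile_idx * BLOCK_SIZE + threadIdx.y;
    // Each thread copies one element per tile; out-of-range positions are zero-padded.
    a_tile[threadIdx.y][threadIdx.x] = (row < M && a_col < K) ? a[row * K + a_col] : 0.0f;
    b_tile[threadIdx.y][threadIdx.x] = (b_row < K && col < N) ? b[b_row * N + col] : 0.0f;
}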
// Tiled matmul: each block stages BLOCK_SIZE x BLOCK_SIZE tiles of A and B in shared memory.
__global__ void org_mm_cuda_share(const float *a, const float *b, float *c, int M, int N, int K)
{
    __shared__ float a_shared[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float b_shared[BLOCK_SIZE][BLOCK_SIZE];
    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    float sum_value = 0.0f;
    // Walk over the K dimension one BLOCK_SIZE-wide tile at a time.
    for (int i = 0; i < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; i++)
    {
        // Every thread copies one element of each tile into shared memory;
        // out-of-range positions are zero-padded so partial tiles stay correct.
        int a_col = i * BLOCK_SIZE + threadIdx.x;
        if (row < M && a_col < K)
            a_shared[threadIdx.y][threadIdx.x] = a[row * K + a_col];
        else
            a_shared[threadIdx.y][threadIdx.x] = 0.0f;
        int b_row = i * BLOCK_SIZE + threadIdx.y;
        if (b_row < K && col < N)
            b_shared[threadIdx.y][threadIdx.x] = b[b_row * N + col];
        else
            b_shared[threadIdx.y][threadIdx.x] = 0.0f;
        // All 32 x 32 threads of the block must reach this barrier, including threads
        // that fall outside the output matrix, so it sits outside any divergent branch.
        __syncthreads();
        // Each thread accumulates the partial dot product for its output element.
        for (int j = 0; j < BLOCK_SIZE; j++)
        {
            sum_value += a_shared[threadIdx.y][j] * b_shared[j][threadIdx.x];
        }
        __syncthreads();
    }
    if (row < M && col < N)
        c[row * N + col] = sum_value;
}
// __device__ __half sigmoid(__half x)
// {
// return __hdiv(__float2half(1.0), __hadd(__float2half(1.0), __hexp(__hneg(x))));
// }
// template <int BLOCK_SIZE, int M, int N, int K>
// void dispatcher()
// {
// }
// fp16 variant of the tiled matmul; accumulation uses half-precision fused multiply-add.
__global__ void org_mm_cuda_share_half(const __half *a, const __half *b, __half *c, int M, int N, int K)
{
    __shared__ __half a_shared[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ __half b_shared[BLOCK_SIZE][BLOCK_SIZE];
    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    __half zero_base = __float2half(0.0f);
    __half sum_value = zero_base;
    // Walk over the K dimension one BLOCK_SIZE-wide tile at a time.
    for (int i = 0; i < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; i++)
    {
        // Every thread copies one element of each tile into shared memory;
        // out-of-range positions are zero-padded so partial tiles stay correct.
        int a_col = i * BLOCK_SIZE + threadIdx.x;
        if (row < M && a_col < K)
            a_shared[threadIdx.y][threadIdx.x] = a[row * K + a_col];
        else
            a_shared[threadIdx.y][threadIdx.x] = zero_base;
        int b_row = i * BLOCK_SIZE + threadIdx.y;
        if (b_row < K && col < N)
            b_shared[threadIdx.y][threadIdx.x] = b[b_row * N + col];
        else
            b_shared[threadIdx.y][threadIdx.x] = zero_base;
        // Every thread in the 32 x 32 block must hit this barrier, so it sits
        // outside any divergent branch.
        __syncthreads();
        // Each thread accumulates its partial dot product with fused multiply-add.
#pragma unroll
        for (int j = 0; j < BLOCK_SIZE; j++)
        {
            sum_value = __hfma(a_shared[threadIdx.y][j], b_shared[j][threadIdx.x], sum_value);
        }
        __syncthreads();
    }
    if (row < M && col < N)
        c[row * N + col] = sum_value;
}
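// Note: accumulating in __half can lose precision for large K. A hedged alternative
// (not what the kernel above does) is to keep the accumulator in float and convert once
// at the end, e.g.:
//
//     float acc = 0.0f;
//     ...
//     acc += __half2float(a_shared[threadIdx.y][j]) * __half2float(b_shared[j][threadIdx.x]);
//     ...
//     c[row * N + col] = __float2half(acc);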
void org_mm(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
{
    int M = a.size(0);
    int K = a.size(1);
    int N = b.size(1);
    // printf("grid.x is %d, grid.y is %d\n", (N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    // grid.x covers the N columns, grid.y covers the M rows (matches row/col in the kernel).
    dim3 block_dim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    org_mm_cuda<<<block_dim, threads>>>(a.data_ptr<float>(), b.data_ptr<float>(), c.data_ptr<float>(), M, N, K);
}
void org_mm_shared(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
{
    int M = a.size(0);
    int K = a.size(1);
    int N = b.size(1);
    // printf("grid.x is %d, grid.y is %d\n", (N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    // grid.x covers the N columns, grid.y covers the M rows (matches row/col in the kernel).
    dim3 block_dim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    org_mm_cuda_share<<<block_dim, threads>>>(a.data_ptr<float>(), b.data_ptr<float>(), c.data_ptr<float>(), M, N, K);
}
void org_mm_shared_half(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
{
    int M = a.size(0);
    int K = a.size(1);
    int N = b.size(1);
    // printf("grid.x is %d, grid.y is %d\n", (N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    // grid.x covers the N columns, grid.y covers the M rows (matches row/col in the kernel).
    dim3 block_dim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    org_mm_cuda_share_half<<<block_dim, threads>>>((__half *)a.data_ptr<at::Half>(),
                                                   (__half *)b.data_ptr<at::Half>(),
                                                   (__half *)c.data_ptr<at::Half>(),
                                                   M, N, K);
}
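// The wrappers above launch their kernels without validating inputs or checking the launch
// result. A minimal sketch of the kind of checks that could be added is shown below;
// org_mm_checked is a hypothetical name, and core.h is assumed to pull in the torch headers
// that define TORCH_CHECK.
void org_mm_checked(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
{
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(), "all tensors must live on the GPU");
    TORCH_CHECK(a.dtype() == at::kFloat && b.dtype() == at::kFloat, "expected float32 inputs");
    TORCH_CHECK(a.size(1) == b.size(0), "inner dimensions of a and b must match");
    org_mm(a, b, c);
    // Surface launch errors instead of letting them fail silently.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "org_mm_cuda launch failed: ", cudaGetErrorString(err));
}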
// __global__ void test_restrict(
// __restrict__ float *data_ptr)
// {
// }
// __global__ void test_const(const float *const_v1, float *const const_v2)
// {
// }
// __global__ void rmsnorm(const float *src)
// {
// }
__global__ void print_idx_kernel()
{
    printf("this is blockDim.x %d\n", blockDim.x);
    printf("this is blockIdx.x %d\n", blockIdx.x);
    printf("this is threadIdx.x %d\n", threadIdx.x);
}
void print_idx()
{
    int block_size = 256;
    int data_len = 1023;
    // One block per 256 elements of a 1023-element range, i.e. a grid of 4 blocks.
    dim3 grid_dim((data_len + block_size - 1) / block_size);
    print_idx_kernel<<<grid_dim, block_size>>>();
}
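// If the Python bindings are not already declared elsewhere (e.g. via core.h), a hedged
// sketch of the usual torch extension binding for these wrappers would look roughly like
// this, assuming core.h includes torch/extension.h:
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// {
//     m.def("org_mm", &org_mm, "naive float32 matmul");
//     m.def("org_mm_shared", &org_mm_shared, "shared-memory tiled float32 matmul");
//     m.def("org_mm_shared_half", &org_mm_shared_half, "shared-memory tiled fp16 matmul");
//     m.def("print_idx", &print_idx, "print block/thread indices");
// }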