#include "core.h"
|
|
#include <iostream>
|
|
#include <cuda_fp16.h>
|
|
#include <mma.h>
|
|
#include <cuda_runtime.h>
|
|
|
|
#define BLOCK_SIZE 32
|
|
|
|
// Naive GEMM: one thread computes one element of C, where C = A (M x K) * B (K x N).
__global__ void org_mm_cuda(const float *a, const float *b, float *c, int M, int N, int K)
{
    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    // printf("row is %d, col is %d\n", row, col);
    int src_len = M * K;
    int dest_len = K * N;
    if (row < M && col < N)
    {
        float sum = 0;
        for (int i = 0; i < K; i++)
        {
            int src_index = row * K + i;
            int dest_index = i * N + col;
            if (src_index < src_len && dest_index < dest_len)
                sum += a[src_index] * b[dest_index];
        }
        c[row * N + col] = sum;
    }
}

// Unfinished placeholder: intended to load one tile of a and b into shared
// memory, but the body is currently empty. It is still called from
// org_mm_cuda_share below, where it is a no-op.
__device__ void load2shared(const float *a, const float *b, int block_size)
{
    int col = threadIdx.x;
    int row = threadIdx.y;
    bool col_first = true;
    bool row_first = true;
    // load col or row by the index.
    for (int i = 0; i < block_size; i++)
    {
        if (col_first)
        {
            // load col by block size
        }
        else
        {
            // load row by block size
        }
    }
}

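// One way the loader above could be completed (a sketch, not the author's
// implementation; the name load_tile, the explicit shared-array parameters and
// the tile_idx argument are assumptions for illustration): each thread copies
// one element of the current K-tile of A and of B into the shared buffers,
// zero-padding anything that falls outside the matrices.
__device__ void load_tile(const float *a, const float *b,
                          float a_tile[BLOCK_SIZE][BLOCK_SIZE],
                          float b_tile[BLOCK_SIZE][BLOCK_SIZE],
                          int tile_idx, int row, int col, int M, int N, int K)
{
    int a_col = tile_idx * BLOCK_SIZE + threadIdx.x;
    a_tile[threadIdx.y][threadIdx.x] = (row < M && a_col < K) ? a[row * K + a_col] : 0.0f;
    int b_row = tile_idx * BLOCK_SIZE + threadIdx.y;
    b_tile[threadIdx.y][threadIdx.x] = (b_row < K && col < N) ? b[b_row * N + col] : 0.0f;
}
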
__global__ void org_mm_cuda_share(const float *a, const float *b, float *c, int M, int N, int K)
{
    __shared__ float a_shared[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float b_shared[BLOCK_SIZE][BLOCK_SIZE];
    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    load2shared(a, b, BLOCK_SIZE); // currently a no-op placeholder (see above)
    float sum_value = 0.0f;
    // Copy data to shared memory tile by tile. Every thread of the block takes
    // part in the loads and the barriers, even if its output element is out of
    // range; otherwise __syncthreads() would be reached by only part of the
    // block and boundary tiles would contain stale data.
    for (int i = 0; i < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; i++)
    {
        int a_col = i * BLOCK_SIZE + threadIdx.x;
        if (row < M && a_col < K)
            a_shared[threadIdx.y][threadIdx.x] = a[row * K + a_col];
        else
            a_shared[threadIdx.y][threadIdx.x] = 0.0f;
        int b_row = i * BLOCK_SIZE + threadIdx.y;
        if (b_row < K && col < N)
            b_shared[threadIdx.y][threadIdx.x] = b[b_row * N + col];
        else
            b_shared[threadIdx.y][threadIdx.x] = 0.0f;
        // All 32 * 32 threads of the block have now filled their tile elements.
        __syncthreads();
        // Each thread accumulates its partial sum. Is there still redundant work here?
        for (int j = 0; j < BLOCK_SIZE; j++)
        {
            sum_value += a_shared[threadIdx.y][j] * b_shared[j][threadIdx.x];
        }
        __syncthreads();
    }
    if (row < M && col < N)
        c[row * N + col] = sum_value;
}

// __device__ __half sigmoid(__half x)
// {
//     return __hdiv(__float2half(1.0), __hadd(__float2half(1.0), __hexp(__hneg(x))));
// }

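// The commented-out half-precision sigmoid above is nearly complete; one
// possible working version is sketched below (an illustration, not the
// author's code). Note that the half-precision exponential intrinsic in
// cuda_fp16.h is hexp(), not __hexp(), and the half math intrinsics require a
// device of compute capability 5.3 or newer.
__device__ __half sigmoid_half(__half x)
{
    return __hdiv(__float2half(1.0f), __hadd(__float2half(1.0f), hexp(__hneg(x))));
}
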
// template <int BLOCK_SIZE, int M, int N, int K>
// void dispatcher()
// {
// }

__global__ void org_mm_cuda_share_half(const __half *a, const __half *b, __half *c, int M, int N, int K)
{
    __shared__ __half a_shared[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ __half b_shared[BLOCK_SIZE][BLOCK_SIZE];
    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    __half zero_base = __float2half(0.0f);
    __half sum_value = __float2half(0.0f);
    // Copy data to shared memory tile by tile; as in the float version, every
    // thread of the block participates in the loads and barriers.
    for (int i = 0; i < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; i++)
    {
        int a_col = i * BLOCK_SIZE + threadIdx.x;
        if (row < M && a_col < K)
            a_shared[threadIdx.y][threadIdx.x] = a[row * K + a_col];
        else
            a_shared[threadIdx.y][threadIdx.x] = zero_base;
        int b_row = i * BLOCK_SIZE + threadIdx.y;
        if (b_row < K && col < N)
            b_shared[threadIdx.y][threadIdx.x] = b[b_row * N + col];
        else
            b_shared[threadIdx.y][threadIdx.x] = zero_base;
        // All 32 * 32 threads of the block have now filled their tile elements.
        __syncthreads();
        // Each thread accumulates its partial sum with fused multiply-add.
#pragma unroll
        for (int j = 0; j < BLOCK_SIZE; j++)
        {
            sum_value = __hfma(a_shared[threadIdx.y][j], b_shared[j][threadIdx.x], sum_value);
        }
        __syncthreads();
    }
    if (row < M && col < N)
        c[row * N + col] = sum_value;
}

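// <mma.h> is included at the top of this file but never used. For reference
// only, a minimal tensor-core (WMMA) tile multiply is sketched below as a
// comment; it is not the author's implementation, assumes M, N and K are
// multiples of 16, an sm_70+ device, 32 threads per block with a grid of
// (N / 16, M / 16) blocks, and accumulates C in float rather than __half.
//
// __global__ void wmma_mm_sketch(const __half *a, const __half *b, float *c, int M, int N, int K)
// {
//     int tile_row = blockIdx.y; // 16-row tile of C handled by this warp
//     int tile_col = blockIdx.x; // 16-column tile of C handled by this warp
//     nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, __half, nvcuda::wmma::row_major> a_frag;
//     nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, __half, nvcuda::wmma::row_major> b_frag;
//     nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> c_frag;
//     nvcuda::wmma::fill_fragment(c_frag, 0.0f);
//     for (int k = 0; k < K; k += 16)
//     {
//         nvcuda::wmma::load_matrix_sync(a_frag, a + tile_row * 16 * K + k, K);
//         nvcuda::wmma::load_matrix_sync(b_frag, b + k * N + tile_col * 16, N);
//         nvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
//     }
//     nvcuda::wmma::store_matrix_sync(c + tile_row * 16 * N + tile_col * 16, c_frag, N, nvcuda::wmma::mem_row_major);
// }
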
void org_mm(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
{
    int M = a.size(0);
    int K = a.size(1);
    int N = b.size(1);
    // printf("grid.x is %d, grid.y is %d\n", (N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);

    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    // grid.x covers the N (column) dimension and grid.y covers the M (row)
    // dimension, matching the blockIdx.x / blockIdx.y usage inside the kernel.
    dim3 block_dim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    org_mm_cuda<<<block_dim, threads>>>(a.data_ptr<float>(), b.data_ptr<float>(), c.data_ptr<float>(), M, N, K);
}

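// None of the kernel launches in these wrappers check for errors. An optional
// helper such as the one below (not part of the original interface; the macro
// name is illustrative) can be placed right after a <<<...>>> launch to surface
// launch failures:
#define CUDA_CHECK_LAST_ERROR()                                                           \
    do                                                                                    \
    {                                                                                     \
        cudaError_t err_ = cudaGetLastError();                                            \
        if (err_ != cudaSuccess)                                                          \
            std::cerr << "CUDA launch failed: " << cudaGetErrorString(err_) << std::endl; \
    } while (0)
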
void org_mm_shared(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
{
    int M = a.size(0);
    int K = a.size(1);
    int N = b.size(1);
    // printf("grid.x is %d, grid.y is %d\n", (N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);

    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    // grid.x covers the N (column) dimension, grid.y covers the M (row) dimension.
    dim3 block_dim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    org_mm_cuda_share<<<block_dim, threads>>>(a.data_ptr<float>(), b.data_ptr<float>(), c.data_ptr<float>(), M, N, K);
}

void org_mm_shared_half(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
{
    int M = a.size(0);
    int K = a.size(1);
    int N = b.size(1);
    // printf("grid.x is %d, grid.y is %d\n", (N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);

    dim3 threads(BLOCK_SIZE, BLOCK_SIZE);
    // grid.x covers the N (column) dimension, grid.y covers the M (row) dimension.
    dim3 block_dim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    org_mm_cuda_share_half<<<block_dim, threads>>>((__half *)a.data_ptr<at::Half>(),
                                                   (__half *)b.data_ptr<at::Half>(),
                                                   (__half *)c.data_ptr<at::Half>(),
                                                   M, N, K);
}

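// A quick host-side sanity check for the three wrappers above (a sketch only;
// it assumes the ATen headers are available through core.h and that the inputs
// are contiguous CUDA tensors of the matching dtype):
//
//     at::Tensor a = at::rand({M, K}, at::device(at::kCUDA).dtype(at::kFloat));
//     at::Tensor b = at::rand({K, N}, at::device(at::kCUDA).dtype(at::kFloat));
//     at::Tensor c = at::zeros({M, N}, a.options());
//     org_mm_shared(a, b, c);
//     TORCH_CHECK(at::allclose(c, at::mm(a, b), 1e-3, 1e-3));
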
// __global__ void test_restrict(
//     __restrict__ float *data_ptr)
// {
// }

// __global__ void test_const(const float *const_v1, float *const const_v2)
// {
// }

// __global__ void rmsnorm(const float *src)
// {
// }

__global__ void print_idx_kernel()
{
    printf("this is blockdim.x %d\n", blockDim.x);
    printf("this is blockidx.x %d\n", blockIdx.x);
    printf("this is threadidx.x %d\n", threadIdx.x);
}

void print_idx()
{
    int block_size = 256;
    int data_len = 1023;
    // grid_dim holds the number of blocks: ceil(1023 / 256) = 4.
    dim3 grid_dim((data_len + block_size - 1) / block_size);
    print_idx_kernel<<<grid_dim, block_size>>>();
    // Device-side printf output is flushed at synchronization points, so wait
    // for the kernel before returning.
    cudaDeviceSynchronize();
}
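
// How these wrappers are exposed to Python is not shown in this file (the
// bindings presumably live behind core.h or in a separate translation unit).
// For reference only, a typical torch-extension binding would look roughly
// like the commented sketch below; the docstrings and the assumption that no
// other PYBIND11_MODULE definition exists are illustrative:
//
// #include <torch/extension.h>
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// {
//     m.def("org_mm", &org_mm, "naive float32 matmul");
//     m.def("org_mm_shared", &org_mm_shared, "tiled shared-memory float32 matmul");
//     m.def("org_mm_shared_half", &org_mm_shared_half, "tiled shared-memory fp16 matmul");
//     m.def("print_idx", &print_idx, "print block/thread indices");
// }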