// torch_ext/csrc/wmma_cu.cu
#include <iostream>
#include <cuda_runtime.h>
#include <cuda_fp16.h> // half type and host-side half <-> float conversions
#include <mma.h>       // WMMA (Tensor Core) API
using namespace nvcuda;
// WMMA tile dimensions: each warp-level MMA operates on a 16x16x16 tile
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
// CUDA kernel for matrix multiplication using Tensor Cores (WMMA).
// Each thread block holds a single warp, and that warp computes one
// WMMA_M x WMMA_N tile of C. All matrices are row-major, and m, n, k
// are assumed to be multiples of the 16x16x16 WMMA tile shape.
__global__ void matrixMulKernel(half *d_C, const half *d_A, const half *d_B, int m, int n, int k)
{
    // Top-left corner of the output tile this warp is responsible for
    int tileRow = blockIdx.y * WMMA_M;
    int tileCol = blockIdx.x * WMMA_N;
    // Declare the fragments
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> c_frag;
    // Initialize the output fragment to zero
    wmma::fill_fragment(c_frag, __float2half(0.0f));
    // Walk along the shared dimension n, one WMMA_K-wide slice at a time
    for (int i = 0; i < n; i += WMMA_K)
    {
        // Load the A tile starting at (tileRow, i) and the B tile starting at (i, tileCol);
        // the last argument is the leading dimension (row stride) of the source matrix
        wmma::load_matrix_sync(a_frag, d_A + tileRow * n + i, n);
        wmma::load_matrix_sync(b_frag, d_B + i * k + tileCol, k);
        // Accumulate c_frag += a_frag * b_frag
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
    // Store the accumulated tile back to global memory in row-major order
    wmma::store_matrix_sync(d_C + tileRow * k + tileCol, c_frag, k, wmma::mem_row_major);
}
// Host code to allocate device memory, launch the kernel, and copy results back
void matrixMul(half *h_C, const half *h_A, const half *h_B, int m, int n, int k)
{
    // Allocate device memory for matrices A, B, and C
    half *d_A, *d_B, *d_C;
    cudaMalloc(&d_A, m * n * sizeof(half));
    cudaMalloc(&d_B, n * k * sizeof(half));
    cudaMalloc(&d_C, m * k * sizeof(half));
    // Copy host data to device
    cudaMemcpy(d_A, h_A, m * n * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B, n * k * sizeof(half), cudaMemcpyHostToDevice);
    // Launch one warp (32 threads) per 16x16 output tile of C
    dim3 dimBlock(32, 1);
    dim3 dimGrid((k + WMMA_N - 1) / WMMA_N, (m + WMMA_M - 1) / WMMA_M);
    matrixMulKernel<<<dimGrid, dimBlock>>>(d_C, d_A, d_B, m, n, k);
    // Copy result from device to host (this cudaMemcpy synchronizes with the kernel)
    cudaMemcpy(h_C, d_C, m * k * sizeof(half), cudaMemcpyDeviceToHost);
    // Free device memory
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
}
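// A minimal CPU reference multiply (not part of the original file; the name
// matrixMulRef is illustrative). It assumes the same row-major layout and can
// be used to sanity-check the WMMA result on small sizes.
void matrixMulRef(half *h_C, const half *h_A, const half *h_B, int m, int n, int k)
{
    for (int i = 0; i < m; ++i)
    {
        for (int j = 0; j < k; ++j)
        {
            float acc = 0.0f;
            for (int p = 0; p < n; ++p)
            {
                acc += __half2float(h_A[i * n + p]) * __half2float(h_B[p * k + j]);
            }
            h_C[i * k + j] = __float2half(acc);
        }
    }
}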
// Helper function to print a row-major matrix of half values
void printMatrix(const half *matrix, int rows, int cols)
{
    for (int i = 0; i < rows; ++i)
    {
        for (int j = 0; j < cols; ++j)
        {
            std::cout << __half2float(matrix[i * cols + j]) << " ";
        }
        std::cout << std::endl;
    }
}
int main()
{
    // WMMA operates on 16x16x16 tiles, so use dimensions that are multiples of 16
    const int m = 16;
    const int n = 16;
    const int k = 16;
    // Initialize host matrices: A holds small repeating values, B is the identity
    half h_A[m * n];
    half h_B[n * k];
    half h_C[m * k];
    for (int i = 0; i < m * n; ++i)
        h_A[i] = __float2half(static_cast<float>(i % 8));
    for (int i = 0; i < n * k; ++i)
        h_B[i] = __float2half((i / k) == (i % k) ? 1.0f : 0.0f);
    // Perform matrix multiplication; since B is the identity, C should equal A
    matrixMul(h_C, h_A, h_B, m, n, k);
    // Print the result
    std::cout << "Matrix A:" << std::endl;
    printMatrix(h_A, m, n);
    std::cout << "Matrix B:" << std::endl;
    printMatrix(h_B, n, k);
    std::cout << "Matrix C (A * B):" << std::endl;
    printMatrix(h_C, m, k);
    return 0;
}
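// Build note (a suggested command, not from the original file): WMMA requires a
// Tensor Core GPU (compute capability 7.0 or newer), so compile with, for example:
//   nvcc -arch=sm_70 wmma_cu.cu -o wmma_demo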