#include <mma.h>
#include <cuda_fp16.h>
#include <iostream>

using namespace nvcuda;

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

// CUDA kernel for matrix multiplication using Tensor Cores.
// A is m x n, B is n x k, C is m x k, all row-major and stored in half precision.
// Each block contains a single warp and computes one WMMA_M x WMMA_N tile of C.
// Matrix dimensions are assumed to be multiples of the 16x16x16 WMMA tile size.
__global__ void matrixMulKernel(half *d_C, const half *d_A, const half *d_B,
                                int m, int n, int k) {
    // First row and column of the C tile handled by this warp.
    int tileRow = blockIdx.y * WMMA_M;
    int tileCol = blockIdx.x * WMMA_N;

    if (tileRow >= m || tileCol >= k) {
        return;  // nothing to do for out-of-range tiles
    }

    // Declare the fragments (A and B operands in half, accumulator in half as well).
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> c_frag;

    // Initialize the output fragment to zero.
    wmma::fill_fragment(c_frag, __float2half(0.0f));

    // Walk along the shared dimension n in steps of WMMA_K, accumulating one
    // 16x16x16 tensor-core multiply per step.
    for (int i = 0; i < n; i += WMMA_K) {
        // Load a 16x16 tile of A (rows starting at tileRow, columns starting at i)
        // and of B (rows starting at i, columns starting at tileCol).
        // The whole warp cooperates in each load; the last argument is the row stride.
        wmma::load_matrix_sync(a_frag, d_A + tileRow * n + i, n);
        wmma::load_matrix_sync(b_frag, d_B + i * k + tileCol, k);

        // Perform the matrix multiply-accumulate: c_frag += a_frag * b_frag.
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }

    // Store the accumulated 16x16 tile of C back to global memory.
    wmma::store_matrix_sync(d_C + tileRow * k + tileCol, c_frag, k, wmma::mem_row_major);
}

// Host code to allocate device memory, launch the kernel, and copy back the result.
void matrixMul(half *h_C, const half *h_A, const half *h_B, int m, int n, int k) {
    // Allocate device memory for matrices A (m x n), B (n x k), and C (m x k).
    half *d_A, *d_B, *d_C;
    cudaMalloc(&d_A, m * n * sizeof(half));
    cudaMalloc(&d_B, n * k * sizeof(half));
    cudaMalloc(&d_C, m * k * sizeof(half));

    // Copy host data to the device.
    cudaMemcpy(d_A, h_A, m * n * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B, n * k * sizeof(half), cudaMemcpyHostToDevice);

    // One warp (32 threads) per block; one block per 16x16 output tile.
    dim3 dimBlock(32, 1);
    dim3 dimGrid((k + WMMA_N - 1) / WMMA_N, (m + WMMA_M - 1) / WMMA_M);

    // Launch the kernel.
    matrixMulKernel<<<dimGrid, dimBlock>>>(d_C, d_A, d_B, m, n, k);

    // Copy the result back to the host (cudaMemcpy synchronizes with the kernel).
    cudaMemcpy(h_C, d_C, m * k * sizeof(half), cudaMemcpyDeviceToHost);

    // Free device memory.
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
}

// Helper function to print a row-major half-precision matrix.
void printMatrix(const half *matrix, int rows, int cols) {
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            std::cout << __half2float(matrix[i * cols + j]) << " ";
        }
        std::cout << std::endl;
    }
}

int main() {
    // WMMA operates on 16x16x16 tiles, so the matrix dimensions must be
    // multiples of 16; use the smallest size that satisfies this.
    const int m = 16;
    const int n = 16;
    const int k = 16;

    half h_A[m * n];
    half h_B[n * k];
    half h_C[m * k];

    // A holds the values 1, 2, 3, ... in row-major order; B is the identity
    // matrix, so the expected result is C == A.
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            h_A[i * n + j] = __float2half(static_cast<float>(i * n + j + 1));
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < k; ++j) {
            h_B[i * k + j] = __float2half(i == j ? 1.0f : 0.0f);
        }
    }

    // Perform the matrix multiplication on the GPU.
    matrixMul(h_C, h_A, h_B, m, n, k);

    // Print the inputs and the result.
    std::cout << "Matrix A:" << std::endl;
    printMatrix(h_A, m, n);
    std::cout << "Matrix B:" << std::endl;
    printMatrix(h_B, n, k);
    std::cout << "Matrix C (A * B):" << std::endl;
    printMatrix(h_C, m, k);

    return 0;
}
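A minimal way to build and run the listing above, assuming the file is saved as wmma_matmul.cu (the filename is arbitrary) and the GPU has Tensor Cores, i.e. compute capability 7.0 or newer, which the WMMA API requires:

nvcc -arch=sm_70 -o wmma_matmul wmma_matmul.cu
./wmma_matmul

Because B is the identity matrix and the values 1..256 are exactly representable in half precision, the printed matrix C should match A exactly.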