"We can then compile and run this operation on a group of GEMMs. We'll first set up some utility functions to initialize GEMMs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2a7371e",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"random.seed(2023)\n",
"\n",
"# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K\n",
"def initialize(dtype, M, N, K):\n",
" sizes = [(M, K), (K, N), (M, N), (M, N)]\n",
" return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]\n",
"\n",
"# Utility function to generate `problems` GEMMs of random sizes\n",
"def generate_problems(problems):\n",
" valid_sizes = [128, 256, 512, 1024]\n",
" As, Bs, Cs, Ds = [], [], [], []\n",
" for _ in range(problems):\n",
" M, N, K = [random.choice(valid_sizes) for _ in range(3)]\n",
" A, B, C, D = initialize(dtype, M, N, K)\n",
" As.append(A)\n",
" Bs.append(B)\n",
" Cs.append(C)\n",
" Ds.append(D)\n",
" return As, Bs, Cs, Ds"
]
},
{
"cell_type": "markdown",
"id": "590a3bc5",
"metadata": {},
"source": [
"We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "776c9233",
"metadata": {},
"outputs": [],
"source": [
"As, Bs, Cs, Ds, = generate_problems(50)\n",
"\n",
"plan.run(As, Bs, Cs, Ds, print_module=True)\n",
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
"\n",
"for d, d_torch in zip(Ds, Ds_torch):\n",
" assert torch.allclose(d, d_torch)"
]
},
{
"cell_type": "markdown",
"id": "766e4f03",
"metadata": {},
"source": [
"## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n",
"The procedure above allows one to quickly experiment with using a CUTLASS kernels However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This will avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.\n",
"\n",
"The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n",
"\n",
"To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:"
"Where `TORCH_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run.\n",
"\n",
"See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details on this.\n",
"\n",
"The PyTorch CUDA extension could be built for this module by running:\n",
"```bash\n",
"cd out\n",
"TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py\n",
"```\n",
"(assuming that one is building for SM80)\n",
"\n",
"One could then use the kernel in a later PyTorch module by running:\n",
"\n",
"```python\n",
"import torch\n",
"import grouped_gemm\n",
"\n",
"grouped_gemm.run(As, Bs)\n",
"```\n",
"\n",
"In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n",
"Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n",
"and returns back the loaded extension.\n",
"\n",
"We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cecb26a4",
"metadata": {},
"outputs": [],
"source": [
"Ds = grouped_gemm.run(As, Bs)\n",
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
"for d, d_torch in zip(Ds, Ds_torch):\n",
" assert torch.allclose(d, d_torch)"
]
},
{
"cell_type": "markdown",
"id": "50db80e4",
"metadata": {},
"source": [
"Finally, we can profile our grouped GEMM extension:"