"[](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/02_pytorch_extension_grouped_gemm.ipynb)\n"
"This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc7c7458",
"metadata": {},
"outputs": [],
"source": [
"!#nvidia-smi"
]
},
{
"cell_type": "markdown",
"id": "2107bb0d",
"metadata": {},
"source": [
"If running on Colab, you will need to install the CUTLASS Python interface and PyTorch. To do so, uncomment the following line and run the cell:"
"## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n",
"The procedure above allows one to quickly experiment with using a CUTLASS kernels However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This will avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.\n",
"\n",
"The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n",
"\n",
"To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:"
"Where `TORCH_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run.\n",
"\n",
"See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details on this.\n",
"\n",
"The PyTorch CUDA extension could be built for this module by running:\n",
"```bash\n",
"cd out\n",
"TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py\n",
"```\n",
"(assuming that one is building for SM80)\n",
"\n",
"One could then use the kernel in a later PyTorch module by running:\n",
"\n",
"```python\n",
"import torch\n",
"import grouped_gemm\n",
"\n",
"grouped_gemm.run(As, Bs)\n",
"```\n",
"\n",
"In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n",
"Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n",
"and returns back the loaded extension.\n",
"\n",
"We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cecb26a4",
"metadata": {},
"outputs": [],
"source": [
"Ds = grouped_gemm.run(As, Bs)\n",
"Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
"for d, d_torch in zip(Ds, Ds_torch):\n",
" assert torch.allclose(d, d_torch)"
]
},
{
"cell_type": "markdown",
"id": "50db80e4",
"metadata": {},
"source": [
"Finally, we can profile our grouped GEMM extension:"