|   | { | ||
|  |  "cells": [ | ||
|  |   { | ||
|  |    "attachments": {}, | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "6acbea5d", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "# Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension\n", | ||
|  |     "This notebook walks through a basic example of using the CUTLASS Python interface to declare\n", | ||
|  |     "a grouped GEMM kernel and export it as a PyTorch CUDA extension.\n", | ||
|  |     "\n", | ||
|  |     "[Open in Colab](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n", | ||
|  |     "\n", | ||
|  |     "## Background on grouped GEMM\n", | ||
|  |     "Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides)\n", | ||
|  |     "in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM,\n", | ||
|  |     "without the requirement that the sizes and strides of each GEMM be the same.\n", | ||
|  |     "\n", | ||
|  |     "For example, if one has `p` GEMMs with sizes:\n", | ||
|  |     "```text\n", | ||
|  |     "M_1 x N_1 x K_1\n", | ||
|  |     "M_2 x N_2 x K_2\n", | ||
|  |     "...\n", | ||
|  |     "M_p x N_p x K_p\n", | ||
|  |     "```\n", | ||
|  |     "CUTLASS's grouped GEMM will execute these in a single CUDA kernel.\n", | ||
|  |     "\n", | ||
|  |     "Grouped GEMM is particularly beneficial for saturating the GPU with many small problems, each of which\n", | ||
|  |     "would underutilize the device if run in isolation.\n", | ||
|  |     "\n", | ||
|  |     "## Declaring a grouped GEMM via the CUTLASS Python interface\n", | ||
|  |     "A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one\n", | ||
|  |     "simply calls `cutlass.op.GroupedGemm`." | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": null, | ||
|  |    "id": "fdcf21d8", | ||
|  |    "metadata": {}, | ||
|  |    "outputs": [], | ||
|  |    "source": [ | ||
|  |     "import cutlass\n", | ||
|  |     "import torch\n", | ||
|  |     "\n", | ||
|  |     "dtype = torch.float16\n", | ||
|  |     "plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "514f40a4", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "We can then compile and run this operation on a group of GEMMs. We'll first set up some utility functions to initialize GEMMs." | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": null, | ||
|  |    "id": "c2a7371e", | ||
|  |    "metadata": {}, | ||
|  |    "outputs": [], | ||
|  |    "source": [ | ||
|  |     "import random\n", | ||
|  |     "random.seed(2023)\n", | ||
|  |     "\n", | ||
|  |     "# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K\n", | ||
|  |     "def initialize(dtype, M, N, K):\n", | ||
|  |     "    sizes = [(M, K), (K, N), (M, N), (M, N)]\n", | ||
|  |     "    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]\n", | ||
|  |     "\n", | ||
|  |     "# Utility function to generate `problems` GEMMs of random sizes\n", | ||
|  |     "def generate_problems(problems):\n", | ||
|  |     "    valid_sizes = [128, 256, 512, 1024]\n", | ||
|  |     "    As, Bs, Cs, Ds = [], [], [], []\n", | ||
|  |     "    for _ in range(problems):\n", | ||
|  |     "        M, N, K = [random.choice(valid_sizes) for _ in range(3)]\n", | ||
|  |     "        A, B, C, D = initialize(dtype, M, N, K)\n", | ||
|  |     "        As.append(A)\n", | ||
|  |     "        Bs.append(B)\n", | ||
|  |     "        Cs.append(C)\n", | ||
|  |     "        Ds.append(D)\n", | ||
|  |     "    return As, Bs, Cs, Ds" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "590a3bc5", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch." | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": null, | ||
|  |    "id": "776c9233", | ||
|  |    "metadata": {}, | ||
|  |    "outputs": [], | ||
|  |    "source": [ | ||
|  |     "As, Bs, Cs, Ds = generate_problems(50)\n", | ||
|  |     "\n", | ||
|  |     "plan.run(As, Bs, Cs, Ds, print_module=True)\n", | ||
|  |     "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", | ||
|  |     "\n", | ||
|  |     "for d, d_torch in zip(Ds, Ds_torch):\n", | ||
|  |     "    assert torch.allclose(d, d_torch)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "766e4f03", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n", | ||
|  |     "The procedure above allows one to quickly experiment with a CUTLASS kernel. However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This avoids the runtime overhead associated with the Python portions of the CUTLASS Python interface.\n", | ||
|  |     "\n", | ||
|  |     "The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n", | ||
|  |     "\n", | ||
|  |     "To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": null, | ||
|  |    "id": "3a98dee6", | ||
|  |    "metadata": {}, | ||
|  |    "outputs": [], | ||
|  |    "source": [ | ||
|  |     "op = plan.construct()\n", | ||
|  |     "grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "c8ca3991", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "The `cutlass.emit.pytorch` function emits:\n", | ||
|  |     "* `out/grouped_gemm_kernel.cu`: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors\n", | ||
|  |     "* `out/grouped_gemm.cpp`: This file contains a C++ wrapper around the aforementioned CUTLASS kernel\n", | ||
|  |     "* `out/setup.py`: This file contains the `setuptools` script for building and installing the generated extension\n", | ||
|  |     "\n", | ||
|  |     "The extension can be built from within the `out` directory by running:\n", | ||
|  |     "```bash\n", | ||
|  |     "cd out\n", | ||
|  |     "TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py install\n", | ||
|  |     "```\n", | ||
|  |     "where `TORCH_CUDA_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run (`8.0` here, assuming that one is building for SM80).\n", | ||
|  |     "\n", | ||
|  |     "See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details on this.\n", | ||
|  |     "\n", | ||
|  |     "One could then use the kernel in a later PyTorch module by running:\n", | ||
|  |     "\n", | ||
|  |     "```python\n", | ||
|  |     "import torch\n", | ||
|  |     "import grouped_gemm\n", | ||
|  |     "\n", | ||
|  |     "grouped_gemm.run(As, Bs)\n", | ||
|  |     "```\n", | ||
|  |     "\n", | ||
|  |     "In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n", | ||
|  |     "Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n", | ||
|  |     "and returns the loaded extension.\n", | ||
|  |     "\n", | ||
|  |     "We can then use the extension and compare its results against those of vanilla PyTorch GEMMs:" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": null, | ||
|  |    "id": "cecb26a4", | ||
|  |    "metadata": {}, | ||
|  |    "outputs": [], | ||
|  |    "source": [ | ||
|  |     "Ds = grouped_gemm.run(As, Bs)\n", | ||
|  |     "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", | ||
|  |     "for d, d_torch in zip(Ds, Ds_torch):\n", | ||
|  |     "    assert torch.allclose(d, d_torch)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "50db80e4", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "Finally, we can profile our grouped GEMM extension:" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": null, | ||
|  |    "id": "b76805d3", | ||
|  |    "metadata": {}, | ||
|  |    "outputs": [], | ||
|  |    "source": [ | ||
|  |     "import time\n", | ||
|  |     "\n", | ||
|  |     "num_warmup = 20\n", | ||
|  |     "num_profile = 100\n", | ||
|  |     "\n", | ||
|  |     "# Warmup iterations, so that one-time startup costs are kept out of the timings\n", | ||
|  |     "for _ in range(num_warmup):\n", | ||
|  |     "    Ds = grouped_gemm.run(As, Bs)\n", | ||
|  |     "    Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", | ||
|  |     "    torch.cuda.synchronize()\n", | ||
|  |     "\n", | ||
|  |     "# Timing iterations. CUDA kernels launch asynchronously, so we synchronize\n", | ||
|  |     "# before reading the clock at the end of each measurement.\n", | ||
|  |     "grouped = 0\n", | ||
|  |     "nongrouped = 0\n", | ||
|  |     "for _ in range(num_profile):\n", | ||
|  |     "    start = time.perf_counter()\n", | ||
|  |     "    Ds = grouped_gemm.run(As, Bs)\n", | ||
|  |     "    torch.cuda.synchronize()\n", | ||
|  |     "    grouped += time.perf_counter() - start\n", | ||
|  |     "\n", | ||
|  |     "    start = time.perf_counter()\n", | ||
|  |     "    Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", | ||
|  |     "    torch.cuda.synchronize()\n", | ||
|  |     "    nongrouped += time.perf_counter() - start\n", | ||
|  |     "\n", | ||
|  |     "# Report the mean time per iteration in microseconds\n", | ||
|  |     "print('Grouped:     {:.3f} us'.format(grouped * 1e6/num_profile))\n", | ||
|  |     "print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))\n", | ||
|  |     "print('Speedup: {:.3f}'.format(nongrouped / grouped))" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": null, | ||
|  |    "id": "f22fc696", | ||
|  |    "metadata": {}, | ||
|  |    "outputs": [], | ||
|  |    "source": [] | ||
|  |   } | ||
|  |  ], | ||
|  |  "metadata": { | ||
|  |   "kernelspec": { | ||
|  |    "display_name": "Python 3 (ipykernel)", | ||
|  |    "language": "python", | ||
|  |    "name": "python3" | ||
|  |   }, | ||
|  |   "language_info": { | ||
|  |    "codemirror_mode": { | ||
|  |     "name": "ipython", | ||
|  |     "version": 3 | ||
|  |    }, | ||
|  |    "file_extension": ".py", | ||
|  |    "mimetype": "text/x-python", | ||
|  |    "name": "python", | ||
|  |    "nbconvert_exporter": "python", | ||
|  |    "pygments_lexer": "ipython3", | ||
|  |    "version": "3.8.10" | ||
|  |   } | ||
|  |  }, | ||
|  |  "nbformat": 4, | ||
|  |  "nbformat_minor": 5 | ||
|  | } |