594 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			594 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| {
 | |
|  "cells": [
 | |
|   {
 | |
|    "attachments": {},
 | |
|    "cell_type": "markdown",
 | |
|    "id": "5d24a692",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "# Example of using elementwise activation functions in the CUTLASS Python interface\n",
 | |
|     "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n",
 | |
|     "\n",
 | |
|     "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "3ca993fe",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "We first import various packages needed for the example and construct the input and output tensors that will be used in our example."
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 1,
 | |
|    "id": "63a70a3c",
 | |
|    "metadata": {
 | |
|     "execution": {
 | |
|      "iopub.execute_input": "2023-04-18T18:00:09.148380Z",
 | |
|      "iopub.status.busy": "2023-04-18T18:00:09.148011Z",
 | |
|      "iopub.status.idle": "2023-04-18T18:00:13.281937Z",
 | |
|      "shell.execute_reply": "2023-04-18T18:00:13.281256Z"
 | |
|     }
 | |
|    },
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stderr",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
 | |
|       "  from .autonotebook import tqdm as notebook_tqdm\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "import numpy as np\n",
 | |
|     "\n",
 | |
|     "import cutlass\n",
 | |
|     "\n",
 | |
|     "# This controls whether the C++ GEMM declaration will be printed at each step. Set to `False` to\n",
 | |
|     "# omit this information.\n",
 | |
|     "print_module = True\n",
 | |
|     "\n",
 | |
|     "m = 256\n",
 | |
|     "n = m\n",
 | |
|     "k = m\n",
 | |
|     "\n",
 | |
|     "type_A = np.float16\n",
 | |
|     "type_B = np.float16\n",
 | |
|     "type_C = np.float16\n",
 | |
|     "type_D = np.float16\n",
 | |
|     "\n",
 | |
|     "np.random.seed(1234)\n",
 | |
|     "scope_min = -4\n",
 | |
|     "scope_max = 4\n",
 | |
|     "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
 | |
|     "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
 | |
|     "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
 | |
|     "\n",
 | |
|     "alpha = np.float16(1.)\n",
 | |
|     "beta = np.float16(0.)\n",
 | |
|     "\n",
 | |
|     "tensor_D = np.zeros(tensor_C.shape).astype(type_D)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "1eb0d95b",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## Run a GEMM with an identity activation function\n",
 | |
|     "To begin, we simply run a default GEMM with an identity activation function. This performs the well-known operation `D = alpha * (A @ B) + beta * C`. This is the default activation function used, and does not need to be specified."
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 2,
 | |
|    "id": "8d257833",
 | |
|    "metadata": {
 | |
|     "execution": {
 | |
|      "iopub.execute_input": "2023-04-18T18:00:13.284650Z",
 | |
|      "iopub.status.busy": "2023-04-18T18:00:13.284425Z",
 | |
|      "iopub.status.idle": "2023-04-18T18:00:18.333867Z",
 | |
|      "shell.execute_reply": "2023-04-18T18:00:18.333187Z"
 | |
|     }
 | |
|    },
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n"
 | |
|      ]
 | |
|     },
 | |
|     {
 | |
|      "data": {
 | |
|       "text/plain": [
 | |
|        "<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed907287c0>"
 | |
|       ]
 | |
|      },
 | |
|      "execution_count": 2,
 | |
|      "metadata": {},
 | |
|      "output_type": "execute_result"
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)\n",
 | |
|     "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "54961694",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## Run a GEMM with a ReLU element-wise activation function\n",
 | |
|     "CUTLASS makes it easy to support other element-wise activation functions. This results in performing an element-wise operation after the generic linear combination performed in a GEMM. If we call such an activation function `act`, the resulting formulation is:\n",
 | |
|     "```\n",
 | |
|     "D = alpha * (A @ B) + beta * C\n",
 | |
|     "D = act(D)\n",
 | |
|     "```\n",
 | |
|     "\n",
 | |
|     "Here, we will add a ReLU activation function. Given an input `x`, ReLU returns `max(x, 0)`.\n",
 | |
|     "\n",
 | |
|     "This is easy to do in CUTLASS. One only needs to set the plan's `activation` field."
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 3,
 | |
|    "id": "5fe49443",
 | |
|    "metadata": {
 | |
|     "execution": {
 | |
|      "iopub.execute_input": "2023-04-18T18:00:18.337036Z",
 | |
|      "iopub.status.busy": "2023-04-18T18:00:18.336833Z",
 | |
|      "iopub.status.idle": "2023-04-18T18:00:23.482072Z",
 | |
|      "shell.execute_reply": "2023-04-18T18:00:23.481125Z"
 | |
|     }
 | |
|    },
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n"
 | |
|      ]
 | |
|     },
 | |
|     {
 | |
|      "data": {
 | |
|       "text/plain": [
 | |
|        "<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed906f2460>"
 | |
|       ]
 | |
|      },
 | |
|      "execution_count": 3,
 | |
|      "metadata": {},
 | |
|      "output_type": "execute_result"
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)\n",
 | |
|     "plan.activation = cutlass.epilogue.relu\n",
 | |
|     "plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "455d0a37",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
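|     "We can now verify the result of the GEMM that used a ReLU activation function by comparing it against a NumPy reference that applies ReLU to the output of the identity-activation GEMM:"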
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 4,
 | |
|    "id": "e32e7798",
 | |
|    "metadata": {
 | |
|     "execution": {
 | |
|      "iopub.execute_input": "2023-04-18T18:00:23.486042Z",
 | |
|      "iopub.status.busy": "2023-04-18T18:00:23.485342Z",
 | |
|      "iopub.status.idle": "2023-04-18T18:00:23.497444Z",
 | |
|      "shell.execute_reply": "2023-04-18T18:00:23.496668Z"
 | |
|     }
 | |
|    },
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D\n",
 | |
|     "np.testing.assert_array_equal(relu_ref, tensor_D_relu)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "cf959171",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## Other element-wise activation functions\n",
 | |
|     "CUTLASS supports a variety of widely-used element-wise activation functions. We can obtain a list of these functions via the plan's `activations()` method."
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 5,
 | |
|    "id": "9e17d730",
 | |
|    "metadata": {
 | |
|     "execution": {
 | |
|      "iopub.execute_input": "2023-04-18T18:00:23.500102Z",
 | |
|      "iopub.status.busy": "2023-04-18T18:00:23.499944Z",
 | |
|      "iopub.status.idle": "2023-04-18T18:00:23.504562Z",
 | |
|      "shell.execute_reply": "2023-04-18T18:00:23.503793Z"
 | |
|     }
 | |
|    },
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "<class 'cutlass.backend.epilogue.gelu'>\n",
 | |
|       "<class 'cutlass.backend.epilogue.hardswish'>\n",
 | |
|       "<class 'cutlass.backend.epilogue.identity'>\n",
 | |
|       "<class 'cutlass.backend.epilogue.leaky_relu'>\n",
 | |
|       "<class 'cutlass.backend.epilogue.relu'>\n",
 | |
|       "<class 'cutlass.backend.epilogue.sigmoid'>\n",
 | |
|       "<class 'cutlass.backend.epilogue.silu'>\n",
 | |
|       "<class 'cutlass.backend.epilogue.tanh'>\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "activations = plan.activations()\n",
 | |
|     "for activation in activations:\n",
 | |
|     "    print(activation)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "0e4599fa",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "We can then compile and run each of them (note that this loop reuses, and overwrites, `tensor_D` as the output tensor):"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 6,
 | |
|    "id": "9c3598c9",
 | |
|    "metadata": {
 | |
|     "execution": {
 | |
|      "iopub.execute_input": "2023-04-18T18:00:23.507538Z",
 | |
|      "iopub.status.busy": "2023-04-18T18:00:23.507257Z",
 | |
|      "iopub.status.idle": "2023-04-18T18:00:59.414765Z",
 | |
|      "shell.execute_reply": "2023-04-18T18:00:59.414116Z"
 | |
|     }
 | |
|    },
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "=============================================================================================\n",
 | |
|       "Compiling and running activation <class 'cutlass.backend.epilogue.gelu'>\n",
 | |
|       "=============================================================================================\n",
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n"
 | |
|      ]
 | |
|     },
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "=============================================================================================\n",
 | |
|       "Compiling and running activation <class 'cutlass.backend.epilogue.hardswish'>\n",
 | |
|       "=============================================================================================\n",
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::HardSwish, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n"
 | |
|      ]
 | |
|     },
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "=============================================================================================\n",
 | |
|       "Compiling and running activation <class 'cutlass.backend.epilogue.identity'>\n",
 | |
|       "=============================================================================================\n",
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n",
 | |
|       "=============================================================================================\n",
 | |
|       "Compiling and running activation <class 'cutlass.backend.epilogue.leaky_relu'>\n",
 | |
|       "=============================================================================================\n",
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::LeakyReLU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n"
 | |
|      ]
 | |
|     },
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "=============================================================================================\n",
 | |
|       "Compiling and running activation <class 'cutlass.backend.epilogue.relu'>\n",
 | |
|       "=============================================================================================\n",
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n",
 | |
|       "=============================================================================================\n",
 | |
|       "Compiling and running activation <class 'cutlass.backend.epilogue.sigmoid'>\n",
 | |
|       "=============================================================================================\n",
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Sigmoid, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n"
 | |
|      ]
 | |
|     },
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "=============================================================================================\n",
 | |
|       "Compiling and running activation <class 'cutlass.backend.epilogue.silu'>\n",
 | |
|       "=============================================================================================\n",
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::SiLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n"
 | |
|      ]
 | |
|     },
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "=============================================================================================\n",
 | |
|       "Compiling and running activation <class 'cutlass.backend.epilogue.tanh'>\n",
 | |
|       "=============================================================================================\n",
 | |
|       "\n",
 | |
|       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
 | |
|       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
 | |
|       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
 | |
|       "    cutlass::half_t, cutlass::layout::RowMajor,\n",
 | |
|       "    cutlass::half_t,\n",
 | |
|       "    cutlass::arch::OpClassTensorOp,\n",
 | |
|       "    cutlass::arch::Sm80,\n",
 | |
|       "    cutlass::gemm::GemmShape<256, 128, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<64, 64, 64>,\n",
 | |
|       "    cutlass::gemm::GemmShape<16, 8, 16>,\n",
 | |
|       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Tanh, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
 | |
|       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
 | |
|       "    3,\n",
 | |
|       "    cutlass::arch::OpMultiplyAdd\n",
 | |
|       ">::GemmKernel;\n",
 | |
|       "\n",
 | |
|       "// Define named type\n",
 | |
|       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
 | |
|       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
 | |
|       "\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "for activation in activations:\n",
 | |
|     "    print('=============================================================================================')\n",
 | |
|     "    print(f'Compiling and running activation {activation}')\n",
 | |
|     "    print('=============================================================================================')\n",
 | |
|     "    plan.activation = activation\n",
 | |
|     "    plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "751f8d92",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": []
 | |
|   }
 | |
|  ],
 | |
|  "metadata": {
 | |
|   "kernelspec": {
 | |
|    "display_name": "Python 3 (ipykernel)",
 | |
|    "language": "python",
 | |
|    "name": "python3"
 | |
|   },
 | |
|   "language_info": {
 | |
|    "codemirror_mode": {
 | |
|     "name": "ipython",
 | |
|     "version": 3
 | |
|    },
 | |
|    "file_extension": ".py",
 | |
|    "mimetype": "text/x-python",
 | |
|    "name": "python",
 | |
|    "nbconvert_exporter": "python",
 | |
|    "pygments_lexer": "ipython3",
 | |
|    "version": "3.8.10"
 | |
|   }
 | |
|  },
 | |
|  "nbformat": 4,
 | |
|  "nbformat_minor": 5
 | |
| }
 | 
