594 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
		
		
			
		
	
	
			594 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
|   | { | ||
|  |  "cells": [ | ||
|  |   { | ||
|  |    "attachments": {}, | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "5d24a692", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "# Example of using elementwise activation functions in the CUTLASS Python interface\n", | ||
|  |     "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n", | ||
|  |     "\n", | ||
|  |     "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "3ca993fe", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "We first import various packages needed for the example and construct the input and output tensors that will be used in our example." | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": 1, | ||
|  |    "id": "63a70a3c", | ||
|  |    "metadata": { | ||
|  |     "execution": { | ||
|  |      "iopub.execute_input": "2023-04-18T18:00:09.148380Z", | ||
|  |      "iopub.status.busy": "2023-04-18T18:00:09.148011Z", | ||
|  |      "iopub.status.idle": "2023-04-18T18:00:13.281937Z", | ||
|  |      "shell.execute_reply": "2023-04-18T18:00:13.281256Z" | ||
|  |     } | ||
|  |    }, | ||
|  |    "outputs": [ | ||
|  |     { | ||
|  |      "name": "stderr", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
|  |       "  from .autonotebook import tqdm as notebook_tqdm\n" | ||
|  |      ] | ||
|  |     } | ||
|  |    ], | ||
|  |    "source": [ | ||
|  |     "import numpy as np\n", | ||
|  |     "\n", | ||
|  |     "import cutlass\n", | ||
|  |     "\n", | ||
|  |     "# This controls whether the C++ GEMM declaration will be printed at each step. Set to `False` to\n", | ||
|  |     "# omit this information.\n", | ||
|  |     "print_module = True\n", | ||
|  |     "\n", | ||
|  |     "m = 256\n", | ||
|  |     "n = m\n", | ||
|  |     "k = m\n", | ||
|  |     "\n", | ||
|  |     "type_A = np.float16\n", | ||
|  |     "type_B = np.float16\n", | ||
|  |     "type_C = np.float16\n", | ||
|  |     "type_D = np.float16\n", | ||
|  |     "\n", | ||
|  |     "np.random.seed(1234)\n", | ||
|  |     "scope_min = -4\n", | ||
|  |     "scope_max = 4\n", | ||
|  |     "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", | ||
|  |     "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", | ||
|  |     "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", | ||
|  |     "\n", | ||
|  |     "alpha = np.float16(1.)\n", | ||
|  |     "beta = np.float16(0.)\n", | ||
|  |     "\n", | ||
|  |     "tensor_D = np.zeros(tensor_C.shape).astype(type_D)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "1eb0d95b", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "## Run a GEMM with an identity activation function\n", | ||
|  |     "To begin, we simply run a default GEMM with an identity activation function. This performs the well-known operation `D = alpha * (A @ B) + beta * C`. This is the default activation function used, and does not need to be specified." | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": 2, | ||
|  |    "id": "8d257833", | ||
|  |    "metadata": { | ||
|  |     "execution": { | ||
|  |      "iopub.execute_input": "2023-04-18T18:00:13.284650Z", | ||
|  |      "iopub.status.busy": "2023-04-18T18:00:13.284425Z", | ||
|  |      "iopub.status.idle": "2023-04-18T18:00:18.333867Z", | ||
|  |      "shell.execute_reply": "2023-04-18T18:00:18.333187Z" | ||
|  |     } | ||
|  |    }, | ||
|  |    "outputs": [ | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n" | ||
|  |      ] | ||
|  |     }, | ||
|  |     { | ||
|  |      "data": { | ||
|  |       "text/plain": [ | ||
|  |        "<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed907287c0>" | ||
|  |       ] | ||
|  |      }, | ||
|  |      "execution_count": 2, | ||
|  |      "metadata": {}, | ||
|  |      "output_type": "execute_result" | ||
|  |     } | ||
|  |    ], | ||
|  |    "source": [ | ||
|  |     "plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)\n", | ||
|  |     "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "54961694", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "## Run a GEMM with a ReLU element-wise activation function\n", | ||
|  |     "CUTLASS makes it easy to support other element-wise activation functions. This results in performing an element-wise operation after the generic linear combination performed in a GEMM. If we call such an activation function `act`, the resulting formulation is:\n", | ||
|  |     "```\n", | ||
|  |     "D = alpha * (A @ B) + beta * C\n", | ||
|  |     "D = act(D)\n", | ||
|  |     "```\n", | ||
|  |     "\n", | ||
|  |     "Here, we will add a ReLU activation function. Given an input `x`, ReLU returns `max(x, 0)`.\n", | ||
|  |     "\n", | ||
|  |     "This is easy to do in CUTLASS. One only needs to set the plan's `activation` field." | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": 3, | ||
|  |    "id": "5fe49443", | ||
|  |    "metadata": { | ||
|  |     "execution": { | ||
|  |      "iopub.execute_input": "2023-04-18T18:00:18.337036Z", | ||
|  |      "iopub.status.busy": "2023-04-18T18:00:18.336833Z", | ||
|  |      "iopub.status.idle": "2023-04-18T18:00:23.482072Z", | ||
|  |      "shell.execute_reply": "2023-04-18T18:00:23.481125Z" | ||
|  |     } | ||
|  |    }, | ||
|  |    "outputs": [ | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n" | ||
|  |      ] | ||
|  |     }, | ||
|  |     { | ||
|  |      "data": { | ||
|  |       "text/plain": [ | ||
|  |        "<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed906f2460>" | ||
|  |       ] | ||
|  |      }, | ||
|  |      "execution_count": 3, | ||
|  |      "metadata": {}, | ||
|  |      "output_type": "execute_result" | ||
|  |     } | ||
|  |    ], | ||
|  |    "source": [ | ||
|  |     "tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)\n", | ||
|  |     "plan.activation = cutlass.epilogue.relu\n", | ||
|  |     "plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "455d0a37", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "We can now verify the result of the GEMM that used a ReLU activation function:" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": 4, | ||
|  |    "id": "e32e7798", | ||
|  |    "metadata": { | ||
|  |     "execution": { | ||
|  |      "iopub.execute_input": "2023-04-18T18:00:23.486042Z", | ||
|  |      "iopub.status.busy": "2023-04-18T18:00:23.485342Z", | ||
|  |      "iopub.status.idle": "2023-04-18T18:00:23.497444Z", | ||
|  |      "shell.execute_reply": "2023-04-18T18:00:23.496668Z" | ||
|  |     } | ||
|  |    }, | ||
|  |    "outputs": [], | ||
|  |    "source": [ | ||
|  |     "relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D\n", | ||
|  |     "np.testing.assert_array_equal(relu_ref, tensor_D_relu)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "cf959171", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "## Other element-wise activation functions\n", | ||
|  |     "CUTLASS supports a variety of widely-used element-wise activation functions. We can obtain a list of these functions via the `activations()` method." | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": 5, | ||
|  |    "id": "9e17d730", | ||
|  |    "metadata": { | ||
|  |     "execution": { | ||
|  |      "iopub.execute_input": "2023-04-18T18:00:23.500102Z", | ||
|  |      "iopub.status.busy": "2023-04-18T18:00:23.499944Z", | ||
|  |      "iopub.status.idle": "2023-04-18T18:00:23.504562Z", | ||
|  |      "shell.execute_reply": "2023-04-18T18:00:23.503793Z" | ||
|  |     } | ||
|  |    }, | ||
|  |    "outputs": [ | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "<class 'cutlass.backend.epilogue.gelu'>\n", | ||
|  |       "<class 'cutlass.backend.epilogue.hardswish'>\n", | ||
|  |       "<class 'cutlass.backend.epilogue.identity'>\n", | ||
|  |       "<class 'cutlass.backend.epilogue.leaky_relu'>\n", | ||
|  |       "<class 'cutlass.backend.epilogue.relu'>\n", | ||
|  |       "<class 'cutlass.backend.epilogue.sigmoid'>\n", | ||
|  |       "<class 'cutlass.backend.epilogue.silu'>\n", | ||
|  |       "<class 'cutlass.backend.epilogue.tanh'>\n" | ||
|  |      ] | ||
|  |     } | ||
|  |    ], | ||
|  |    "source": [ | ||
|  |     "activations = plan.activations()\n", | ||
|  |     "for activation in activations:\n", | ||
|  |     "    print(activation)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "markdown", | ||
|  |    "id": "0e4599fa", | ||
|  |    "metadata": {}, | ||
|  |    "source": [ | ||
|  |     "We can then run each of them:" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": 6, | ||
|  |    "id": "9c3598c9", | ||
|  |    "metadata": { | ||
|  |     "execution": { | ||
|  |      "iopub.execute_input": "2023-04-18T18:00:23.507538Z", | ||
|  |      "iopub.status.busy": "2023-04-18T18:00:23.507257Z", | ||
|  |      "iopub.status.idle": "2023-04-18T18:00:59.414765Z", | ||
|  |      "shell.execute_reply": "2023-04-18T18:00:59.414116Z" | ||
|  |     } | ||
|  |    }, | ||
|  |    "outputs": [ | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "=============================================================================================\n", | ||
|  |       "Compiling and running activation <class 'cutlass.backend.epilogue.gelu'>\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n" | ||
|  |      ] | ||
|  |     }, | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "=============================================================================================\n", | ||
|  |       "Compiling and running activation <class 'cutlass.backend.epilogue.hardswish'>\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::HardSwish, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n" | ||
|  |      ] | ||
|  |     }, | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "=============================================================================================\n", | ||
|  |       "Compiling and running activation <class 'cutlass.backend.epilogue.identity'>\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "Compiling and running activation <class 'cutlass.backend.epilogue.leaky_relu'>\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::LeakyReLU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n" | ||
|  |      ] | ||
|  |     }, | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "=============================================================================================\n", | ||
|  |       "Compiling and running activation <class 'cutlass.backend.epilogue.relu'>\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "Compiling and running activation <class 'cutlass.backend.epilogue.sigmoid'>\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Sigmoid, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n" | ||
|  |      ] | ||
|  |     }, | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "=============================================================================================\n", | ||
|  |       "Compiling and running activation <class 'cutlass.backend.epilogue.silu'>\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::SiLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n" | ||
|  |      ] | ||
|  |     }, | ||
|  |     { | ||
|  |      "name": "stdout", | ||
|  |      "output_type": "stream", | ||
|  |      "text": [ | ||
|  |       "=============================================================================================\n", | ||
|  |       "Compiling and running activation <class 'cutlass.backend.epilogue.tanh'>\n", | ||
|  |       "=============================================================================================\n", | ||
|  |       "\n", | ||
|  |       "// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n", | ||
|  |       "using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n", | ||
|  |       "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n", | ||
|  |       "    cutlass::half_t, cutlass::layout::RowMajor,\n", | ||
|  |       "    cutlass::half_t,\n", | ||
|  |       "    cutlass::arch::OpClassTensorOp,\n", | ||
|  |       "    cutlass::arch::Sm80,\n", | ||
|  |       "    cutlass::gemm::GemmShape<256, 128, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<64, 64, 64>,\n", | ||
|  |       "    cutlass::gemm::GemmShape<16, 8, 16>,\n", | ||
|  |       "    cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Tanh, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n", | ||
|  |       "    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n", | ||
|  |       "    3,\n", | ||
|  |       "    cutlass::arch::OpMultiplyAdd\n", | ||
|  |       ">::GemmKernel;\n", | ||
|  |       "\n", | ||
|  |       "// Define named type\n", | ||
|  |       "struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n", | ||
|  |       "  public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n", | ||
|  |       "\n" | ||
|  |      ] | ||
|  |     } | ||
|  |    ], | ||
|  |    "source": [ | ||
|  |     "for activation in activations:\n", | ||
|  |     "    print('=============================================================================================')\n", | ||
|  |     "    print(f'Compiling and running activation {activation}')\n", | ||
|  |     "    print('=============================================================================================')\n", | ||
|  |     "    plan.activation = activation\n", | ||
|  |     "    plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" | ||
|  |    ] | ||
|  |   }, | ||
|  |   { | ||
|  |    "cell_type": "code", | ||
|  |    "execution_count": null, | ||
|  |    "id": "751f8d92", | ||
|  |    "metadata": {}, | ||
|  |    "outputs": [], | ||
|  |    "source": [] | ||
|  |   } | ||
|  |  ], | ||
|  |  "metadata": { | ||
|  |   "kernelspec": { | ||
|  |    "display_name": "Python 3 (ipykernel)", | ||
|  |    "language": "python", | ||
|  |    "name": "python3" | ||
|  |   }, | ||
|  |   "language_info": { | ||
|  |    "codemirror_mode": { | ||
|  |     "name": "ipython", | ||
|  |     "version": 3 | ||
|  |    }, | ||
|  |    "file_extension": ".py", | ||
|  |    "mimetype": "text/x-python", | ||
|  |    "name": "python", | ||
|  |    "nbconvert_exporter": "python", | ||
|  |    "pygments_lexer": "ipython3", | ||
|  |    "version": "3.8.10" | ||
|  |   } | ||
|  |  }, | ||
|  |  "nbformat": 4, | ||
|  |  "nbformat_minor": 5 | ||
|  | } |