329 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			329 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 | |
| #
 | |
| # Redistribution and use in source and binary forms, with or without modification, are permitted
 | |
| # provided that the following conditions are met:
 | |
| #     * Redistributions of source code must retain the above copyright notice, this list of
 | |
| #       conditions and the following disclaimer.
 | |
| #     * Redistributions in binary form must reproduce the above copyright notice, this list of
 | |
| #       conditions and the following disclaimer in the documentation and/or other materials
 | |
| #       provided with the distribution.
 | |
| #     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 | |
| #       to endorse or promote products derived from this software without specific prior written
 | |
| #       permission.
 | |
| #
 | |
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 | |
| # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 | |
| # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 | |
| # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 | |
| # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 | |
| # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 | |
| # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
| 
 | |
| # this file creates the test/unit/gemm/device simt tests
 | |
| 
 | |
| 
 | |
# Prefix glued directly in front of every generated file name (no path
# separator is inserted); the empty string writes into the current directory.
outputDir = ""

################################################################################
# parameters
# Edge - for tiles, the edges represent the length of one side
# Ratio - the maximum ratio between 2 edges, limits the skinniness of tiles
# MaxEdge - maximum length of each edge
# Min/Max - minimum/maximum of the product of edge lengths
################################################################################

# Warps along each side of a threadblock.
warpsPerThreadblockEdge = [1, 2, 4, 8, 16]
warpsPerThreadblockRatio = 2
warpsPerThreadblockMax = 16
# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases

# Elements along each side of a warp tile.
warpShapeEdges = [8, 16, 32, 64, 128, 256]
warpShapeRatio = 4
warpShapeMax = 64*64
warpShapeMin = 8*8

# Longest permitted side of a threadblock tile, in elements.
threadblockEdgeMax = 256

# One row per precision:
#      char,      type             bits/elem, max tile,    L0 threadblock tiles
# L1 will have a single kernel for every unique shape
# L2 will have everything else
precisions = [
    ["c", "cutlass::complex<float>",   64,  64*128, [[ 64, 128], [ 64,  32]]],
    ["d", "double",                    64,   64*64, [[ 64,  64], [ 32,  32]]],
    ["h", "cutlass::half_t",           16, 128*256, [[256, 128], [ 64, 128], [ 64,  32]]],
    ["i", "int",                       32, 128*128, [[128,  64], [ 16,  32]]],
    ["s", "float",                     32, 128*128, [[128, 256], [128, 128], [ 64,  64]]],
    ["z", "cutlass::complex<double>", 128,   64*64, [[ 32,  64], [ 16,  32]]],
]

# Every [A, B] flag pair, in the order FF, FT, TF, TT.
transposes = [[flagA, flagB] for flagA in (False, True) for flagB in (False, True)]

################################################################################
# warps per threadblock
################################################################################
# Keep every ordered edge pair whose long/short ratio and whose product both
# stay within the configured limits.
warpsPerThreadblocks = [
    [edgeM, edgeN]
    for edgeM in warpsPerThreadblockEdge
    for edgeN in warpsPerThreadblockEdge
    if max(edgeM, edgeN) / min(edgeM, edgeN) <= warpsPerThreadblockRatio
    and edgeM * edgeN <= warpsPerThreadblockMax
]
print("WarpsPerThreadblocks",warpsPerThreadblocks)

################################################################################
# warp shapes
################################################################################
warpNumThreads = 32
# Same filter as above, with the extra requirement that the tile area is
# strictly greater than warpShapeMin.
warpShapes = [
    [edgeM, edgeN]
    for edgeM in warpShapeEdges
    for edgeN in warpShapeEdges
    if max(edgeM, edgeN) / min(edgeM, edgeN) <= warpShapeRatio
    and warpShapeMin < edgeM * edgeN <= warpShapeMax
]
print("WarpShapes", warpShapes)
 | |
| 
 | |
# Running totals of emitted kernels per test level, across all generated files.
numL0 = 0
numL1 = 0
numL2 = 0

################################################################################
# create kernels
# create a file for each precision/transpose
# each file contains many tile sizes
################################################################################

# License banner and include preamble written at the top of every generated
# file. It does not depend on the precision or transpose, so it is built once.
fileHeader = (
    "/***************************************************************************************************\n"
    " * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.\n"
    " *\n"
    " * Redistribution and use in source and binary forms, with or without modification, are permitted\n"
    " * provided that the following conditions are met:\n"
    " *     * Redistributions of source code must retain the above copyright notice, this list of\n"
    " *       conditions and the following disclaimer.\n"
    " *     * Redistributions in binary form must reproduce the above copyright notice, this list of\n"
    " *       conditions and the following disclaimer in the documentation and/or other materials\n"
    " *       provided with the distribution.\n"
    " *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used\n"
    " *       to endorse or promote products derived from this software without specific prior written\n"
    " *       permission.\n"
    " *\n"
    " * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR\n"
    " * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND\n"
    " * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE\n"
    " * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,\n"
    " * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;\n"
    " * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,\n"
    " * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n"
    " * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
    " *\n"
    " **************************************************************************************************/\n"
    "/*! \\file\n"
    "    \\brief Tests for device-wide GEMM interface\n"
    "*/\n"
    "\n"
    "#include <iostream>\n"
    "\n"
    "#include \"cutlass/cutlass.h\"\n"
    "#include \"cutlass/gemm/device/gemm.h\"\n"
    "#include \"cutlass/numeric_types.h\"\n"
    "\n"
    "#include \"../../common/cutlass_unit_test.h\"\n"
    "\n"
    "#include \"cutlass/util/host_tensor.h\"\n"
    "#include \"cutlass/util/tensor_view_io.h\"\n"
    "#include \"cutlass/util/reference/host/tensor_fill.h\"\n"
    "#include \"cutlass/util/reference/host/tensor_copy.h\"\n"
    "#include \"cutlass/util/reference/host/tensor_compare.h\"\n"
    "#include \"cutlass/util/reference/host/gemm.h\"\n"
    "\n"
    "#include \"testbed.h\"\n"
    "\n")

# Upper bound on the A+B shared-memory staging buffers: 48 KiB, expressed in
# bits so the comparison below needs no division.
smemLimitBits = 48 * 1024 * 8

# precisions
for precision in precisions:

    # unpack the precision row
    precisionChar = precision[0]
    precisionType = precision[1]
    precisionBits = precision[2]
    threadblockMaxElements = precision[3]
    threadblockTilesL0 = precision[4]

    # transposes
    for transpose in transposes:

        # get transpose char
        columnMajorA = transpose[0]
        columnMajorB = transpose[1]
        transCharA = "n" if columnMajorA else "t"
        transCharB = "n" if columnMajorB else "t"

        # open file
        fileName="simt_%sgemm_%s%s_sm50.cu" % (precisionChar, transCharA, transCharB)
        print("\n", fileName)
        filePath = "%s%s" % (outputDir, fileName)

        # The with-block closes the file even if an exception is raised while
        # the kernels are being written.
        with open(filePath, "w+") as out:

            # write file header
            out.write(fileHeader)

            # Threadblock tiles already assigned to level 0 / level 1 in this file.
            foundThreadblockTilesL0 = set()
            foundThreadblockTilesL1 = set()

            ####################################################################
            # for each combination of tile sizes
            ####################################################################
            for warpsPerThreadblock in warpsPerThreadblocks:
                for warpShape in warpShapes:
                    # Arrange the warp's threads 8xN when the warp tile is
                    # taller than wide, otherwise 4xN. Integer division keeps
                    # every derived quantity an int.
                    warpThreadsM = 8 if warpShape[0] > warpShape[1] else 4
                    warpThreadsN = warpNumThreads // warpThreadsM

                    # skip shapes with conflicting rectangularity
                    # they are unlikely to be fastest
                    blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                    blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                    warpG = warpShape[0] > warpShape[1]
                    warpL = warpShape[0] < warpShape[1]

                    blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2
                    blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1]
                    warpG2 = warpShape[0] > warpShape[1]*2
                    warpL2 = warpShape[0]*2 < warpShape[1]

                    if blockG2 and warpL: continue
                    if blockL2 and warpG: continue
                    if warpG2 and blockL: continue
                    if warpL2 and blockG: continue

                    # check threadblock ratios and max
                    threadblockTile = [warpShape[0]*warpsPerThreadblock[0],
                            warpShape[1]*warpsPerThreadblock[1]]
                    if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue
                    if threadblockTile[0] > threadblockEdgeMax: continue
                    if threadblockTile[1] > threadblockEdgeMax: continue
                    totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1]

                    # calculate unroll
                    # ensure that every iteration at least a full load of A,B are done
                    unrollMin = 8
                    unrollMin0 = totalThreads // threadblockTile[0]
                    unrollMin1 = totalThreads // threadblockTile[1]
                    unroll = max(unrollMin, unrollMin0, unrollMin1)

                    threadTileM = warpShape[0] // warpThreadsM
                    threadTileN = warpShape[1] // warpThreadsN
                    if threadTileM < 2 or threadTileN < 2: continue
                    if threadTileM*threadTileN*precisionBits > 8*8*32: continue

                    # epilogue currently only supports N < WarpNumThreads
                    if threadblockTile[1] < warpNumThreads: continue

                    # limit smem (factor 2 matches the two stages requested
                    # in the Gemm instantiation below)
                    smemBitsA = threadblockTile[0]*unroll*2*precisionBits
                    smemBitsB = threadblockTile[1]*unroll*2*precisionBits
                    if smemBitsA + smemBitsB > smemLimitBits: continue

                    # Pick the test level:
                    #   L0 - first kernel hitting one of the precision's L0 tiles
                    #   L1 - first kernel for any other (or repeated L0) tile shape
                    #   L2 - everything else
                    tileKey = tuple(threadblockTile)
                    if threadblockTile in threadblockTilesL0 and tileKey not in foundThreadblockTilesL0:
                        testLevel = 0
                        numL0 += 1
                        foundThreadblockTilesL0.add(tileKey)
                    elif tileKey not in foundThreadblockTilesL1:
                        testLevel = 1
                        numL1 += 1
                        foundThreadblockTilesL1.add(tileKey)
                    else:
                        testLevel = 2
                        numL2 += 1

                    ############################################################
                    # write this tile to file
                    ############################################################

                    print("%ix%ix%i__%ix%i_%ix%i_%ix%i L%i" % (
                            threadblockTile[0], threadblockTile[1], unroll,
                            threadTileM, threadTileN,
                            warpThreadsM, warpThreadsN,
                            warpsPerThreadblock[0], warpsPerThreadblock[1], testLevel))

                    out.write("////////////////////////////////////////////////////////////////////////////////\n"
                            "// Elements / Thread: %3i x %3i\n"
                            "//    Threads / Warp: %3i x %3i\n"
                            "//     Warps / Block: %3i x %3i\n"
                            "//       Threadblock: %3i x %3i x %2i\n"
                            % ( threadTileM, threadTileN,
                                warpThreadsM, warpThreadsN,
                                warpsPerThreadblock[0], warpsPerThreadblock[1],
                                threadblockTile[0], threadblockTile[1], unroll
                                )
                            )

                    out.write("CUTLASS_TEST_L%i(SM50_device_%sgemm_%s%s, %ix%ix%i_%ix%ix1_%ix%i_%ix%i_%ix%i, {\n" % (
                        testLevel,
                        precisionChar,
                        transCharA,
                        transCharB,
                        threadblockTile[0],
                        threadblockTile[1],
                        unroll,
                        warpShape[0],
                        warpShape[1],
                        threadTileM,
                        threadTileN,
                        warpThreadsM,
                        warpThreadsN,
                        warpsPerThreadblock[0],
                        warpsPerThreadblock[1]
                        ))
                    out.write("    using precision = %s;\n" % precisionType)
                    out.write("    using ThreadblockShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n" % (
                        threadblockTile[0],
                        threadblockTile[1],
                        unroll))
                    out.write("    using WarpShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n\n" % (
                        warpShape[0],
                        warpShape[1],
                        unroll))
                    out.write("    static int const kEpilogueElementsPerAccess = 1;\n"
                        "    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;\n"
                        "    using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<\n"
                        "        precision, kEpilogueElementsPerAccess, precision, precision>;\n\n")

                    out.write("    using Gemm = cutlass::gemm::device::Gemm<\n"
                        "        precision, cutlass::layout::%sMajor,\n"
                        "        precision, cutlass::layout::%sMajor,\n"
                        "        precision, cutlass::layout::RowMajor,\n"
                        "        precision,\n"
                        "        cutlass::arch::OpClassSimt,\n"
                        "        cutlass::arch::Sm50,\n"
                        "        ThreadblockShape, WarpShape, InstructionShape,\n"
                        "        EpilogueOutputOp,\n"
                        "        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle,\n"
                        "        2 // Stages\n"
                        "    >;\n" % (
                            "Column" if columnMajorA else "Row",
                            "Column" if columnMajorB else "Row",
                            ))
                    out.write("    EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());\n"
                        "} )\n\n")

print("NumKernels:", numL0, numL1, numL2)
 | |
| 
 | 
