| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  | /*************************************************************************************************** | 
					
						
							| 
									
										
										
										
											2024-01-17 03:37:22 +08:00
										 |  |  |  * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | /*! \file | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
    \brief Unit tests for the threadblock-scoped EpilogueWorkspace epilogue
					
						
							| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
#include "../../common/cutlass_unit_test.h"

#include <vector>

#include "cutlass/gemm_coord.h"
#include "cutlass/epilogue/epilogue_workspace.h"
					
						
							| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | namespace test { | 
					
						
							|  |  |  | namespace gemm { | 
					
						
							|  |  |  | namespace threadblock { | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | /// Kernel computes accumulator data and stores it out | 
					
						
							|  |  |  | template <typename Epilogue> | 
					
						
							|  |  |  | __global__ void kernel_epilogue_workspace(typename Epilogue::Params params) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   __shared__ typename Epilogue::SharedStorage shared_storage; | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   int warp_id = threadIdx.y; | 
					
						
							|  |  |  |   int lane_id = threadIdx.x; | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   Epilogue epilogue(params, shared_storage, warp_id, lane_id); | 
					
						
							| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  |   // | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   // Initialize accumulator tile | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  |   // | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   typename Epilogue::FragmentC accum; | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |   for (int i = 0; i < Epilogue::FragmentC::kElements; ++i) { | 
					
						
							|  |  |  |     accum[i] = Element(warp_id * blockDim.x + lane_id); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   // | 
					
						
							|  |  |  |   // Efficient epilogue | 
					
						
							|  |  |  |   // | 
					
						
							| 
									
										
										
										
											2018-05-17 02:44:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   cutlass::GemmCoord tb_tile_coord{blockIdx.x, blockIdx.y, 0}; | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   cutlass::GemmCoord problem_size =  | 
					
						
							|  |  |  |     tb_tile_coord *  | 
					
						
							|  |  |  |     cutlass::GemmCoord{Epilogue::Shape::kM, Epilogue::Shape::kN, 1}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Store accumulators | 
					
						
							|  |  |  |   epilogue( | 
					
						
							|  |  |  |     problem_size,  | 
					
						
							|  |  |  |     tb_tile_coord,  | 
					
						
							|  |  |  |     accum); | 
					
						
							| 
									
										
										
										
											2018-09-19 07:58:03 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | } // namespace threadblock | 
					
						
							|  |  |  | } // namespace gemm | 
					
						
							|  |  |  | } // namespace test | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | TEST(SM75_gemm_threadblock_epilogue_workspace, tensor_op_128x128_64x64) { | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   // Define an instance of the epilogue and see if it works | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  |   // | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   static int const kWarpCount = 4; | 
					
						
							|  |  |  |   static int const kWarpSize = 32; | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   using Shape = cutlass::MatrixShape<128, 128>; | 
					
						
							|  |  |  |   using FragmentC = cutlass::Array<int, Shape::kCount / (kWarpCount * kWarpSize)>; | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   using Epilogue = cutlass::gemm::threadblock::EpilogueWorkspace< | 
					
						
							|  |  |  |     Shape, | 
					
						
							|  |  |  |     kWarpCount, | 
					
						
							|  |  |  |     FragmentC | 
					
						
							|  |  |  |   >; | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   typename Epilogue::Params params( | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |   ); | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   // Launch the kernel | 
					
						
							|  |  |  |   dim3 grid(1,1); | 
					
						
							|  |  |  |   dim3 block(kWarpSize, kWarpCount); | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   test::gemm::threadblock::kernel_epilogue_workspace<Epilogue><<< grid, block >>>( | 
					
						
							|  |  |  |     params | 
					
						
							|  |  |  |   ); | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   cudaError_t result = cudaDeviceSynchronize(); | 
					
						
							|  |  |  |   EXPECT_EQ(result, cudaSuccess) << "Kernel launch error - " << cudaGetErrorString(result); | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   // | 
					
						
							|  |  |  |   //  | 
					
						
							|  |  |  |   // | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// |