| 
									
										
										
										
											2018-10-27 05:38:46 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							| 
									
										
										
										
											2019-03-21 01:49:17 +08:00
										 |  |  |  * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved. | 
					
						
							| 
									
										
										
										
											2018-10-27 05:38:46 +08:00
										 |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without modification, are permitted | 
					
						
							|  |  |  |  * provided that the following conditions are met: | 
					
						
							|  |  |  |  *     * Redistributions of source code must retain the above copyright notice, this list of | 
					
						
							|  |  |  |  *       conditions and the following disclaimer. | 
					
						
							|  |  |  |  *     * Redistributions in binary form must reproduce the above copyright notice, this list of | 
					
						
							|  |  |  |  *       conditions and the following disclaimer in the documentation and/or other materials | 
					
						
							|  |  |  |  *       provided with the distribution. | 
					
						
							|  |  |  |  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used | 
					
						
							|  |  |  |  *       to endorse or promote products derived from this software without specific prior written | 
					
						
							|  |  |  |  *       permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND | 
					
						
							|  |  |  |  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | 
					
						
							|  |  |  |  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; | 
					
						
							|  |  |  |  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, | 
					
						
							|  |  |  |  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | /*! \file
 | 
					
						
							|  |  |  |     \brief Thread-level reference implementation for GEMM (cutlass::reference::device::thread). | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/coord.h"
 | 
					
						
							|  |  |  | #include "cutlass/matrix_traits.h"
 | 
					
						
							|  |  |  | #include "cutlass/tensor_view.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/gemm_coord.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "tools/util/reference/detail/inner_product.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass { | 
					
						
							|  |  |  | namespace reference { | 
					
						
							|  |  |  | namespace device { | 
					
						
							|  |  |  | namespace thread { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ////////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Thread-level blocked general matrix product.
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // Note, this is a reference implementation. Performance is not expected to approach peak.
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   typename TensorRefA, | 
					
						
							|  |  |  |   typename TensorRefB, | 
					
						
							|  |  |  |   typename TensorRefC, | 
					
						
							|  |  |  |   typename ScalarType, | 
					
						
							|  |  |  |   typename AccumulatorType, | 
					
						
							|  |  |  |   typename OutputTile | 
					
						
							|  |  |  | > | 
					
						
							|  |  |  | struct Gemm { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   typedef typename TensorRefA::Storage ScalarA; | 
					
						
							|  |  |  |   typedef typename TensorRefB::Storage ScalarB; | 
					
						
							|  |  |  |   typedef typename TensorRefC::Storage ScalarC; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Data members
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Tile for A operand
 | 
					
						
							|  |  |  |   ScalarA A_tile[OutputTile::kW]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Tile for B operand
 | 
					
						
							|  |  |  |   ScalarB B_tile[OutputTile::kH]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Tile for Accumulator
 | 
					
						
							|  |  |  |   AccumulatorType accum[OutputTile::kH][OutputTile::kW]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Methods
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Constructor
 | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Clear fetch registers
 | 
					
						
							|  |  |  |     for (int i = 0; i < OutputTile::kW; ++i) { | 
					
						
							|  |  |  |       A_tile[i] = ScalarA(0); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int j = 0; j < OutputTile::kW; ++j) { | 
					
						
							|  |  |  |       B_tile[j] = ScalarB(0); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Clear accumulators
 | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int j = 0; j < OutputTile::kH; ++j) { | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int i = 0; i < OutputTile::kW; ++i) { | 
					
						
							|  |  |  |         accum[j][i] = initial_accum; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Computes a matrix product
 | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   Gemm & multiply_add( | 
					
						
							|  |  |  |     gemm::GemmCoord problem_size, | 
					
						
							|  |  |  |     TensorRefA tensor_a, | 
					
						
							|  |  |  |     TensorRefB tensor_b, | 
					
						
							|  |  |  |     MatrixCoord output_coord = MatrixCoord()) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Loop over the GEMM K dimension
 | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_NO_UNROLL | 
					
						
							|  |  |  |     for (int k = 0; k < problem_size.k(); ++k) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Fetch a slice of the A matrix
 | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int i = 0; i < OutputTile::kW; ++i) { | 
					
						
							|  |  |  |         if (output_coord.row() + i < problem_size.m()) { | 
					
						
							|  |  |  |           A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Fetch a slice of the B matrix
 | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int j = 0; j < OutputTile::kH; ++j) { | 
					
						
							|  |  |  |         if (output_coord.column() + j < problem_size.n()) { | 
					
						
							|  |  |  |           B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compute an accumulated matrix product
 | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int j = 0; j < OutputTile::kH; ++j) { | 
					
						
							|  |  |  |         CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |         for (int i = 0; i < OutputTile::kW; ++i) { | 
					
						
							|  |  |  |           accum[j][i] = detail::inner_product(A_tile[i], B_tile[j], accum[j][i]); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return *this; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Performs linear scaling of matrix product and updates output tensor
 | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   Gemm & epilogue( | 
					
						
							|  |  |  |     gemm::GemmCoord problem_size, | 
					
						
							|  |  |  |     ScalarType alpha, | 
					
						
							|  |  |  |     ScalarType beta, | 
					
						
							|  |  |  |     TensorRefC tensor_c, | 
					
						
							|  |  |  |     MatrixCoord output_coord = MatrixCoord()) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Update the output tensor
 | 
					
						
							|  |  |  |     for (int j = 0; j < OutputTile::kH; ++j) { | 
					
						
							|  |  |  |       for (int i = 0; i < OutputTile::kW; ++i) { | 
					
						
							|  |  |  |         MatrixCoord coord = output_coord + MatrixCoord(i, j); | 
					
						
							|  |  |  |         if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           tensor_c.at(coord) = detail::Cast<ScalarType, ScalarC>::apply( | 
					
						
							|  |  |  |             alpha * ScalarType(accum[j][i]) + | 
					
						
							|  |  |  |             beta * ScalarType(tensor_c.at(coord)) | 
					
						
							|  |  |  |           ); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return *this; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ////////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace thread
 | 
					
						
							|  |  |  | } // namespace device
 | 
					
						
							|  |  |  | } // namespace reference
 | 
					
						
							|  |  |  | } // namespace cutlass
 |