177 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
		
		
			
		
	
	
			177 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
|   | /***************************************************************************************************
 | ||
|  |  * Copyright (c) 2017-2018, NVIDIA CORPORATION.  All rights reserved. | ||
|  |  * | ||
|  |  * Redistribution and use in source and binary forms, with or without modification, are permitted | ||
|  |  * provided that the following conditions are met: | ||
|  |  *     * Redistributions of source code must retain the above copyright notice, this list of | ||
|  |  *       conditions and the following disclaimer. | ||
|  |  *     * Redistributions in binary form must reproduce the above copyright notice, this list of | ||
|  |  *       conditions and the following disclaimer in the documentation and/or other materials | ||
|  |  *       provided with the distribution. | ||
|  |  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used | ||
|  |  *       to endorse or promote products derived from this software without specific prior written | ||
|  |  *       permission. | ||
|  |  * | ||
|  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR | ||
|  |  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND | ||
|  |  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE | ||
|  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | ||
|  |  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; | ||
|  |  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, | ||
|  |  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
|  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|  |  * | ||
|  |  **************************************************************************************************/ | ||
|  | /*! \file
 | ||
|  |     \brief Thread-level reference implementation for GEMM (namespace cutlass::reference::device::thread). | ||
|  | */ | ||
|  | 
 | ||
|  | #pragma once
 | ||
|  | 
 | ||
|  | #include "cutlass/coord.h"
 | ||
|  | #include "cutlass/matrix_traits.h"
 | ||
|  | #include "cutlass/tensor_view.h"
 | ||
|  | #include "cutlass/gemm/gemm_coord.h"
 | ||
|  | 
 | ||
|  | #include "tools/util/reference/detail/inner_product.h"
 | ||
|  | 
 | ||
|  | namespace cutlass { | ||
|  | namespace reference { | ||
|  | namespace device { | ||
|  | namespace thread { | ||
|  | 
 | ||
|  | ////////////////////////////////////////////////////////////////////////////////////////////////////
 | ||
|  | 
 | ||
|  | /// Thread-level blocked general matrix product.
 | ||
|  | //
 | ||
|  | // Note, this is a reference implementation. Performance is not expected to approach peak.
 | ||
|  | //
 | ||
|  | template < | ||
|  |   typename TensorRefA, | ||
|  |   typename TensorRefB, | ||
|  |   typename TensorRefC, | ||
|  |   typename ScalarType, | ||
|  |   typename AccumulatorType, | ||
|  |   typename OutputTile | ||
|  | > | ||
|  | struct Gemm { | ||
|  | 
 | ||
|  |   typedef typename TensorRefA::Storage ScalarA; | ||
|  |   typedef typename TensorRefB::Storage ScalarB; | ||
|  |   typedef typename TensorRefC::Storage ScalarC; | ||
|  | 
 | ||
|  |   //
 | ||
|  |   // Data members
 | ||
|  |   //
 | ||
|  | 
 | ||
|  |   /// Tile for A operand
 | ||
|  |   ScalarA A_tile[OutputTile::kW]; | ||
|  | 
 | ||
|  |   /// Tile for B operand
 | ||
|  |   ScalarB B_tile[OutputTile::kH]; | ||
|  | 
 | ||
|  |   /// Tile for Accumulator
 | ||
|  |   AccumulatorType accum[OutputTile::kH][OutputTile::kW]; | ||
|  | 
 | ||
|  |   //
 | ||
|  |   // Methods
 | ||
|  |   //
 | ||
|  | 
 | ||
|  |   /// Constructor
 | ||
|  |   CUTLASS_HOST_DEVICE | ||
|  |   Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { | ||
|  | 
 | ||
|  |     // Clear fetch registers
 | ||
|  |     for (int i = 0; i < OutputTile::kW; ++i) { | ||
|  |       A_tile[i] = ScalarA(0); | ||
|  |     } | ||
|  | 
 | ||
|  |     for (int j = 0; j < OutputTile::kW; ++j) { | ||
|  |       B_tile[j] = ScalarB(0); | ||
|  |     } | ||
|  | 
 | ||
|  |     // Clear accumulators
 | ||
|  |     CUTLASS_PRAGMA_UNROLL | ||
|  |     for (int j = 0; j < OutputTile::kH; ++j) { | ||
|  |       CUTLASS_PRAGMA_UNROLL | ||
|  |       for (int i = 0; i < OutputTile::kW; ++i) { | ||
|  |         accum[j][i] = initial_accum; | ||
|  |       } | ||
|  |     } | ||
|  |   } | ||
|  | 
 | ||
|  |   /// Computes a matrix product
 | ||
|  |   CUTLASS_HOST_DEVICE | ||
|  |   Gemm & multiply_add( | ||
|  |     gemm::GemmCoord problem_size, | ||
|  |     TensorRefA tensor_a, | ||
|  |     TensorRefB tensor_b, | ||
|  |     MatrixCoord output_coord = MatrixCoord()) { | ||
|  | 
 | ||
|  |     // Loop over the GEMM K dimension
 | ||
|  |     CUTLASS_PRAGMA_NO_UNROLL | ||
|  |     for (int k = 0; k < problem_size.k(); ++k) { | ||
|  | 
 | ||
|  |       // Fetch a slice of the A matrix
 | ||
|  |       CUTLASS_PRAGMA_UNROLL | ||
|  |       for (int i = 0; i < OutputTile::kW; ++i) { | ||
|  |         if (output_coord.row() + i < problem_size.m()) { | ||
|  |           A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); | ||
|  |         } | ||
|  |       } | ||
|  | 
 | ||
|  |       // Fetch a slice of the B matrix
 | ||
|  |       CUTLASS_PRAGMA_UNROLL | ||
|  |       for (int j = 0; j < OutputTile::kH; ++j) { | ||
|  |         if (output_coord.column() + j < problem_size.n()) { | ||
|  |           B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); | ||
|  |         } | ||
|  |       } | ||
|  | 
 | ||
|  |       // Compute an accumulated matrix product
 | ||
|  |       CUTLASS_PRAGMA_UNROLL | ||
|  |       for (int j = 0; j < OutputTile::kH; ++j) { | ||
|  |         CUTLASS_PRAGMA_UNROLL | ||
|  |         for (int i = 0; i < OutputTile::kW; ++i) { | ||
|  |           accum[j][i] = detail::inner_product(A_tile[i], B_tile[j], accum[j][i]); | ||
|  |         } | ||
|  |       } | ||
|  |     } | ||
|  | 
 | ||
|  |     return *this; | ||
|  |   } | ||
|  | 
 | ||
|  |   /// Performs linear scaling of matrix product and updates output tensor
 | ||
|  |   CUTLASS_HOST_DEVICE | ||
|  |   Gemm & epilogue( | ||
|  |     gemm::GemmCoord problem_size, | ||
|  |     ScalarType alpha, | ||
|  |     ScalarType beta, | ||
|  |     TensorRefC tensor_c, | ||
|  |     MatrixCoord output_coord = MatrixCoord()) { | ||
|  | 
 | ||
|  |     // Update the output tensor
 | ||
|  |     for (int j = 0; j < OutputTile::kH; ++j) { | ||
|  |       for (int i = 0; i < OutputTile::kW; ++i) { | ||
|  |         MatrixCoord coord = output_coord + MatrixCoord(i, j); | ||
|  |         if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { | ||
|  | 
 | ||
|  |           tensor_c.at(coord) = detail::Cast<ScalarType, ScalarC>::apply( | ||
|  |             alpha * ScalarType(accum[j][i]) + | ||
|  |             beta * ScalarType(tensor_c.at(coord)) | ||
|  |           ); | ||
|  |         } | ||
|  |       } | ||
|  |     } | ||
|  | 
 | ||
|  |     return *this; | ||
|  |   } | ||
|  | }; | ||
|  | 
 | ||
|  | ////////////////////////////////////////////////////////////////////////////////////////////////////
 | ||
|  | 
 | ||
|  | } // namespace thread
 | ||
|  | } // namespace device
 | ||
|  | } // namespace reference
 | ||
|  | } // namespace cutlass
 |