cutlass/tools/util/reference/device/kernel/gemm.h

/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Reference implementation for GEMM in device-side code.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/tensor_view.h"
#include "cutlass/gemm/gemm_coord.h"
#include "tools/util/reference/device/thread/gemm.h"
namespace cutlass {
namespace reference {
namespace device {
namespace kernel {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects.
template <
  typename TensorRefA,
  typename TensorRefB,
  typename TensorRefC,
  typename ScalarType,
  typename AccumulatorType,
  typename OutputTile
>
__global__ void Gemm(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRefA tensor_a,
  TensorRefB tensor_b,
  ScalarType beta,
  TensorRefC tensor_c,
  AccumulatorType initial_accum) {

  // Map each thread to a unique tile of the output matrix
  MatrixCoord output_coord(
    (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kW,
    (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kH
  );

  // Compute the general matrix product
  thread::Gemm<
    TensorRefA,
    TensorRefB,
    TensorRefC,
    ScalarType,
    AccumulatorType,
    OutputTile
  > gemm(initial_accum);

  gemm.multiply_add(
    problem_size,
    tensor_a,
    tensor_b,
    output_coord);

  gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
}
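
////////////////////////////////////////////////////////////////////////////////////////////////////

// A minimal, illustrative host-side launch sketch for the Gemm kernel above. This helper is not
// part of CUTLASS; the name `launch_gemm_reference_sketch` and the 16x16 thread block shape are
// assumptions chosen for exposition. It mirrors the kernel's thread-to-tile mapping: the x
// dimension advances the first (row) coordinate in strides of OutputTile::kW, and the y dimension
// advances the second (column) coordinate in strides of OutputTile::kH.
template <
  typename TensorRefA,
  typename TensorRefB,
  typename TensorRefC,
  typename ScalarType,
  typename AccumulatorType,
  typename OutputTile
>
void launch_gemm_reference_sketch(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRefA tensor_a,
  TensorRefB tensor_b,
  ScalarType beta,
  TensorRefC tensor_c,
  AccumulatorType initial_accum) {

  // One thread computes one OutputTile of the product (assumed block shape).
  dim3 block(16, 16);

  // Enough blocks to cover the M x N extent of the output, rounding up.
  dim3 grid(
    (problem_size.m() + block.x * OutputTile::kW - 1) / (block.x * OutputTile::kW),
    (problem_size.n() + block.y * OutputTile::kH - 1) / (block.y * OutputTile::kH)
  );

  Gemm<TensorRefA, TensorRefB, TensorRefC, ScalarType, AccumulatorType, OutputTile>
    <<< grid, block >>>(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
}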
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes a batched general matrix product among matrices (tensors of rank=2) pointed to by
/// TensorRefCollection objects, one batch per blockIdx.z.
template <
  typename TensorRefCollectionA,
  typename TensorRefCollectionB,
  typename TensorRefCollectionC,
  typename ScalarType,
  typename AccumulatorType,
  typename OutputTile
>
__global__ void BatchedGemm(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRefCollectionA tensor_collection_a,
  TensorRefCollectionB tensor_collection_b,
  ScalarType beta,
  TensorRefCollectionC tensor_collection_c,
  AccumulatorType initial_accum) {

  // Obtain batch ID
  int batch_id = blockIdx.z;

  // Dereference based on batch_id
  typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id);
  typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id);
  typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id);

  // Map each thread to a unique tile of the output matrix
  MatrixCoord output_coord(
    (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kW,
    (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kH
  );

  // Compute the general matrix product
  thread::Gemm<
    typename TensorRefCollectionA::TensorRef,
    typename TensorRefCollectionB::TensorRef,
    typename TensorRefCollectionC::TensorRef,
    ScalarType,
    AccumulatorType,
    OutputTile
  > gemm(initial_accum);

  gemm.multiply_add(
    problem_size,
    tensor_a,
    tensor_b,
    output_coord);

  gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
}
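
////////////////////////////////////////////////////////////////////////////////////////////////////

// An illustrative host-side launch sketch for the BatchedGemm kernel above. This is not CUTLASS
// API; the helper name, the 16x16 thread block shape, and the explicit `batch_count` parameter
// are assumptions for exposition (the kernel itself only requires that each collection support
// `at(batch_id)`). The grid's z dimension selects the batch index, while x and y tile the output
// exactly as in the non-batched sketch above.
template <
  typename TensorRefCollectionA,
  typename TensorRefCollectionB,
  typename TensorRefCollectionC,
  typename ScalarType,
  typename AccumulatorType,
  typename OutputTile
>
void launch_batched_gemm_reference_sketch(
  gemm::GemmCoord problem_size,
  int batch_count,
  ScalarType alpha,
  TensorRefCollectionA tensor_collection_a,
  TensorRefCollectionB tensor_collection_b,
  ScalarType beta,
  TensorRefCollectionC tensor_collection_c,
  AccumulatorType initial_accum) {

  dim3 block(16, 16);

  // x and y cover the M x N output; z enumerates the batches, one per blockIdx.z.
  dim3 grid(
    (problem_size.m() + block.x * OutputTile::kW - 1) / (block.x * OutputTile::kW),
    (problem_size.n() + block.y * OutputTile::kH - 1) / (block.y * OutputTile::kH),
    batch_count
  );

  BatchedGemm<
    TensorRefCollectionA, TensorRefCollectionB, TensorRefCollectionC,
    ScalarType, AccumulatorType, OutputTile
  ><<< grid, block >>>(
    problem_size, alpha, tensor_collection_a, tensor_collection_b,
    beta, tensor_collection_c, initial_accum);
}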
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace device
} // namespace reference
} // namespace cutlass