// cutlass/tools/library/src/reference/gemm_reference_operation.h

/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines reference operations for GEMM operation kinds in the CUTLASS Library
*/
#pragma once
#include <iostream>
#include <sstream>
#include <cstring>
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/util.h"
#include "library_internal.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/device/gemm_complex.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Reference GEMM operation for the CUTLASS Library.
///
/// Implements the library `Operation` interface on top of the reference GEMM routines
/// cutlass::reference::host::GemmComplex and cutlass::reference::device::GemmComplex.
/// `Provider_` selects which of the two run() dispatches to.
///
/// Template parameters:
///   Provider_            Provider recorded in the description and dispatched on in run()
///   ElementA_, LayoutA_  Element type and layout of operand A
///   TransformA           Complex transform forwarded to the reference GEMM for operand A
///   ElementB_, LayoutB_  Element type and layout of operand B
///   TransformB           Complex transform forwarded to the reference GEMM for operand B
///   ElementC_, LayoutC_  Element type and layout shared by the C and D tensors
///   ElementCompute_      Type in which alpha and beta are read
///   ElementAccumulator_  Accumulator type (defaults to ElementCompute_)
///   ConvertOp_           Conversion functor forwarded to the reference GEMM
///   InnerProductOp_      Inner-product functor forwarded to the reference GEMM
template <
  Provider Provider_,
  typename ElementA_,
  typename LayoutA_,
  cutlass::ComplexTransform TransformA,
  typename ElementB_,
  typename LayoutB_,
  cutlass::ComplexTransform TransformB,
  typename ElementC_,
  typename LayoutC_,
  typename ElementCompute_,
  typename ElementAccumulator_ = ElementCompute_,
  typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
class GemmReferenceOperation : public Operation {
public:
  static Provider const kProvider = Provider_;

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA, LayoutA>;
  static cutlass::ComplexTransform const kTransformA = TransformA;

  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB, LayoutB>;
  static cutlass::ComplexTransform const kTransformB = TransformB;

  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using TensorRefC = TensorRef<ElementC, LayoutC>;

  using ElementCompute = ElementCompute_;
  using ElementAccumulator = ElementAccumulator_;
  using ConvertOp = ConvertOp_;
  using InnerProductOp = InnerProductOp_;

protected:

  /// Storage for the name string; description_.name points into this buffer
  std::string name_;

  /// Description returned by description()
  GemmDescription description_;

public:

  /// Constructor: fills in the operation description and builds the procedural name
  GemmReferenceOperation() {

    // Basic information
    description_.provider = kProvider;
    description_.kind = OperationKind::kGemm;
    description_.gemm_kind = GemmKind::kUniversal;

    // Tensor description
    description_.A = make_TensorDescription<ElementA, LayoutA>();
    description_.transform_A = ComplexTransformMap<kTransformA>::kId;
    description_.B = make_TensorDescription<ElementB, LayoutB>();
    description_.transform_B = ComplexTransformMap<kTransformB>::kId;
    description_.C = make_TensorDescription<ElementC, LayoutC>();

    // Epilogue compute and accumulator type description. These are assigned exactly once;
    // the accumulator id must be set before the name is assembled because the name reads it.
    description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
    description_.tile_description.math_instruction.element_accumulator =
      NumericTypeMap<ElementAccumulator>::kId;

    // Compute capability for gemm reference: the device provider advertises a minimum of 50,
    // every other provider a minimum of 0; the maximum is 1024 in both cases.
    description_.tile_description.minimum_compute_capability =
      (kProvider == Provider::kReferenceDevice ? 50 : 0);
    description_.tile_description.maximum_compute_capability = 1024;

    // Procedural name:
    //   gemm_reference_<provider>_<A element><A layout>_<B element><B layout>
    //                            _<C element><C layout>_<accumulator element>
    //
    // NOTE(review): transform_A / transform_B are not part of the name, so instances that
    // differ only in their complex transforms produce identical names.
    std::stringstream ss;

    ss << "gemm"
      << "_reference_" << to_string(description_.provider)
      << "_" << to_string(description_.A.element) << to_string(description_.A.layout)
      << "_" << to_string(description_.B.element) << to_string(description_.B.layout)
      << "_" << to_string(description_.C.element) << to_string(description_.C.layout)
      << "_" << to_string(description_.tile_description.math_instruction.element_accumulator);

    name_ = ss.str();

    // NOTE(review): this pointer refers to name_'s buffer. A compiler-generated copy of this
    // object would copy the pointer as-is and keep referring to the source object's string.
    description_.name = name_.c_str();
  }

  /// Returns the description of the GEMM operation
  virtual OperationDescription const & description() const {
    return description_;
  }

  /// Always reports success; neither the configuration nor the arguments are inspected
  virtual Status can_implement(
    void const *configuration,
    void const *arguments) const {

    return Status::kSuccess;
  }

  /// Host workspace holds one GemmUniversalConfiguration (see initialize())
  virtual uint64_t get_host_workspace_size(
    void const *configuration) const {

    return sizeof(GemmUniversalConfiguration);
  }

  /// No device workspace is requested
  virtual uint64_t get_device_workspace_size(
    void const *configuration) const {

    return 0;
  }

  /// Copies get_host_workspace_size() bytes from `configuration` into `host_workspace`.
  ///
  /// Neither pointer is checked for null. `device_workspace` and `stream` are not used.
  virtual Status initialize(
    void const *configuration,
    void *host_workspace,
    void *device_workspace = nullptr,
    cudaStream_t stream = nullptr) const {

    std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration));

    return Status::kSuccess;
  }

  /// Runs the reference GEMM selected by kProvider.
  ///
  /// `host_workspace` is read as the GemmUniversalConfiguration stored by initialize() and
  /// `arguments` as a GemmUniversalArguments. Returns Status::kSuccess after invoking the
  /// host or device reference routine, or Status::kErrorNotSupported for any other provider.
  /// `device_workspace` and `stream` are not used by this function.
  virtual Status run(
    void const *arguments,
    void *host_workspace,
    void *device_workspace = nullptr,
    cudaStream_t stream = nullptr) const {

    GemmUniversalConfiguration const &config =
      *static_cast<GemmUniversalConfiguration const *>(host_workspace);

    GemmUniversalArguments const &args =
      *static_cast<GemmUniversalArguments const *>(arguments);

    // The scalars are dereferenced here, in host code, for both providers.
    // NOTE(review): presumably args.alpha / args.beta must therefore be host-readable
    // pointers -- confirm against the callers.
    ElementCompute const alpha = *static_cast<ElementCompute const *>(args.alpha);
    ElementCompute const beta = *static_cast<ElementCompute const *>(args.beta);

    // A, B and C arrive as pointers to const; TensorRef is instantiated with non-const
    // element types, hence the const_cast. Leading dimensions are narrowed to int.
    TensorRefA ref_A{static_cast<ElementA *>(const_cast<void *>(args.A)), LayoutA(int(config.lda))};
    TensorRefB ref_B{static_cast<ElementB *>(const_cast<void *>(args.B)), LayoutB(int(config.ldb))};
    TensorRefC ref_C{static_cast<ElementC *>(const_cast<void *>(args.C)), LayoutC(int(config.ldc))};
    TensorRefC ref_D{static_cast<ElementC *>(args.D), LayoutC(int(config.ldd))};

    if (kProvider == Provider::kReferenceHost) {

      cutlass::reference::host::GemmComplex<
        ElementA,
        LayoutA,
        ElementB,
        LayoutB,
        ElementC,
        LayoutC,
        ElementCompute,
        ElementAccumulator,
        ConvertOp,
        InnerProductOp
      >(
        config.problem_size,
        alpha,
        ref_A,
        kTransformA,
        ref_B,
        kTransformB,
        beta,
        ref_C,
        ref_D,
        ElementAccumulator(),   // value-initialized accumulator passed as the initial value
        config.batch_count,
        args.batch_stride_A,
        args.batch_stride_B,
        args.batch_stride_C,
        args.batch_stride_D
      );

      return Status::kSuccess;
    }
    else if (kProvider == Provider::kReferenceDevice) {

      cutlass::reference::device::GemmComplex<
        ElementA,
        LayoutA,
        ElementB,
        LayoutB,
        ElementC,
        LayoutC,
        ElementCompute,
        ElementAccumulator,
        ConvertOp,
        InnerProductOp
      >(
        config.problem_size,
        alpha,
        ref_A,
        kTransformA,
        ref_B,
        kTransformB,
        beta,
        ref_C,
        ref_D,
        ElementAccumulator(),   // value-initialized accumulator passed as the initial value
        config.batch_count,
        args.batch_stride_A,
        args.batch_stride_B,
        args.batch_stride_C,
        args.batch_stride_D
      );

      return Status::kSuccess;
    }

    return Status::kErrorNotSupported;
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Registers one reference GEMM with the manifest for each reference provider: first the
/// host-provider operation, then the device-provider operation, both built from the same
/// element, layout, transform and functor parameters.
///
/// The operations are allocated with `new` and handed to Manifest::append().
/// NOTE(review): presumably the manifest takes ownership of the pointers -- confirm.
template <
  typename ElementA_,
  typename LayoutA_,
  cutlass::ComplexTransform TransformA,
  typename ElementB_,
  typename LayoutB_,
  cutlass::ComplexTransform TransformB,
  typename ElementC_,
  typename LayoutC_,
  typename ElementCompute_,
  typename ElementAccumulator_ = ElementCompute_,
  typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_gemm(Manifest &manifest) {

  // Operation dispatching to the host reference implementation
  using HostOperation = GemmReferenceOperation<
    Provider::kReferenceHost,
    ElementA_, LayoutA_, TransformA,
    ElementB_, LayoutB_, TransformB,
    ElementC_, LayoutC_,
    ElementCompute_,
    ElementAccumulator_,
    ConvertOp_,
    InnerProductOp_
  >;

  // Operation dispatching to the device reference implementation
  using DeviceOperation = GemmReferenceOperation<
    Provider::kReferenceDevice,
    ElementA_, LayoutA_, TransformA,
    ElementB_, LayoutB_, TransformB,
    ElementC_, LayoutC_,
    ElementCompute_,
    ElementAccumulator_,
    ConvertOp_,
    InnerProductOp_
  >;

  manifest.append(new HostOperation);
  manifest.append(new DeviceOperation);
}
/// Registers reference GEMMs for the four A/B layout combinations, in this order:
/// (ColumnMajor, ColumnMajor), (ColumnMajor, RowMajor), (RowMajor, ColumnMajor) and
/// (RowMajor, RowMajor). The C layout is ColumnMajor in every case.
template <
  typename ElementA_, cutlass::ComplexTransform TransformA,
  typename ElementB_, cutlass::ComplexTransform TransformB,
  typename ElementC_,
  typename ElementCompute_,
  typename ElementAccumulator_ = ElementCompute_,
  typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_gemm_canonical_layouts(Manifest &manifest) {

  using ColumnMajor = cutlass::layout::ColumnMajor;
  using RowMajor = cutlass::layout::RowMajor;

  // A: column-major, B: column-major
  make_gemm<
    ElementA_, ColumnMajor, TransformA,
    ElementB_, ColumnMajor, TransformB,
    ElementC_, ColumnMajor,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);

  // A: column-major, B: row-major
  make_gemm<
    ElementA_, ColumnMajor, TransformA,
    ElementB_, RowMajor, TransformB,
    ElementC_, ColumnMajor,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);

  // A: row-major, B: column-major
  make_gemm<
    ElementA_, RowMajor, TransformA,
    ElementB_, ColumnMajor, TransformB,
    ElementC_, ColumnMajor,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);

  // A: row-major, B: row-major
  make_gemm<
    ElementA_, RowMajor, TransformA,
    ElementB_, RowMajor, TransformB,
    ElementC_, ColumnMajor,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);
}
/// Registers the reference GEMM with row-major A, column-major B and column-major C, with no
/// complex transform on either operand.
///
/// `InterleaveK` is accepted as a template parameter but is not referenced by this function;
/// no interleaved-layout operation is registered here.
template <
  int InterleaveK,
  typename ElementA_,
  typename ElementB_,
  typename ElementC_,
  typename ElementCompute_,
  typename ElementAccumulator_ = ElementCompute_,
  typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_gemm_interleaved_layouts(Manifest &manifest) {

  using RowMajor = cutlass::layout::RowMajor;
  using ColumnMajor = cutlass::layout::ColumnMajor;

  make_gemm<
    ElementA_, RowMajor, cutlass::ComplexTransform::kNone,
    ElementB_, ColumnMajor, cutlass::ComplexTransform::kNone,
    ElementC_, ColumnMajor,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);
}
/// Helper to register real-valued GEMMs with canonical layouts: forwards to
/// make_gemm_canonical_layouts with ComplexTransform::kNone applied to both A and B.
template <
  typename ElementA_,
  typename ElementB_,
  typename ElementC_,
  typename ElementCompute_,
  typename ElementAccumulator_ = ElementCompute_,
  typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_gemm_real_canonical_layouts(Manifest &manifest) {

  // Both operands use the same transform value
  constexpr cutlass::ComplexTransform kNoTransform = cutlass::ComplexTransform::kNone;

  make_gemm_canonical_layouts<
    ElementA_, kNoTransform,
    ElementB_, kNoTransform,
    ElementC_,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);
}
/// Helper to register canonical-layout GEMMs for every pairing of the kNone and kConjugate
/// transforms on A and B. Registration order is (A, B) = (kNone, kNone),
/// (kConjugate, kConjugate), (kNone, kConjugate), (kConjugate, kNone).
template <
  typename ElementA_,
  typename ElementB_,
  typename ElementC_,
  typename ElementCompute_,
  typename ElementAccumulator_ = ElementCompute_,
  typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
  typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_gemm_complex_canonical_layouts(Manifest &manifest) {

  constexpr cutlass::ComplexTransform kPlain = cutlass::ComplexTransform::kNone;
  constexpr cutlass::ComplexTransform kConj = cutlass::ComplexTransform::kConjugate;

  // A: none, B: none
  make_gemm_canonical_layouts<
    ElementA_, kPlain,
    ElementB_, kPlain,
    ElementC_,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);

  // A: conjugate, B: conjugate
  make_gemm_canonical_layouts<
    ElementA_, kConj,
    ElementB_, kConj,
    ElementC_,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);

  // A: none, B: conjugate
  make_gemm_canonical_layouts<
    ElementA_, kPlain,
    ElementB_, kConj,
    ElementC_,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);

  // A: conjugate, B: none
  make_gemm_canonical_layouts<
    ElementA_, kConj,
    ElementB_, kPlain,
    ElementC_,
    ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_
  >(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////