2020-04-08 04:51:25 +08:00
|
|
|
/***************************************************************************************************
|
2021-02-26 22:58:26 +08:00
|
|
|
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
2020-04-08 04:51:25 +08:00
|
|
|
*
|
|
|
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
|
|
|
* provided that the following conditions are met:
|
|
|
|
* * Redistributions of source code must retain the above copyright notice, this list of
|
|
|
|
* conditions and the following disclaimer.
|
|
|
|
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
|
|
|
* conditions and the following disclaimer in the documentation and/or other materials
|
|
|
|
* provided with the distribution.
|
|
|
|
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
|
|
|
* to endorse or promote products derived from this software without specific prior written
|
|
|
|
* permission.
|
|
|
|
*
|
|
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
|
|
|
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
|
|
|
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
|
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
|
|
|
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
|
|
|
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
|
|
|
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*
|
|
|
|
**************************************************************************************************/
|
|
|
|
|
|
|
|
/*! \file
|
|
|
|
\brief CUTLASS Library handle.
|
|
|
|
*/
|
2020-06-09 07:17:35 +08:00
|
|
|
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>

#include "cutlass/library/handle.h"
#include "cutlass/library/singleton.h"
#include "cutlass/library/util.h"
|
|
|
|
|
|
|
|
namespace cutlass {
|
|
|
|
namespace library {
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Constructs a handle bound to the CUDA device that is current on the calling thread.
///
/// The handle starts with the CUTLASS provider, host-side scalar pointers, no recorded
/// operation, and a device workspace of `workspace_size` bytes obtained through
/// set_workspace_size().
///
/// Throws std::runtime_error if the current device or its properties cannot be queried, or
/// if set_workspace_size() fails.
Handle::Handle(
  cudaStream_t stream,
  size_t workspace_size
):
  provider_(Provider::kCUTLASS),
  stream_(stream),
  workspace_(nullptr),
  workspace_size_(0),
  scalar_pointer_mode_(ScalarPointerMode::kHost),
  last_operation_(nullptr) {

  int current_device = -1;

  if (cudaGetDevice(&current_device) != cudaSuccess) {
    throw std::runtime_error("cudaGetDevice() failed");
  }

  // Cache the device properties; compute_capability() is derived from them later.
  if (cudaGetDeviceProperties(&device_, current_device) != cudaSuccess) {
    throw std::runtime_error("cudaGetDeviceProperties() failed");
  }

  set_workspace_size(workspace_size);

  // Force the library singleton into existence now. The return value is unused here; the GEMM
  // entry points later consult Singleton::get().operation_table.
  Singleton::get();
}
|
|
|
|
|
|
|
|
/// Destructor: releases the device workspace (if any) with cudaFree.
///
/// The cudaFree status is intentionally not checked; destructors must not throw.
Handle::~Handle() {
  if (!workspace_) {
    return;
  }

  cudaFree(workspace_);

  workspace_ = nullptr;
  workspace_size_ = 0;
}
|
|
|
|
|
|
|
|
/// Move constructor: takes ownership of `handle`'s device workspace and copies its configuration
/// (provider, device properties, stream, scalar pointer mode and last selected operation).
///
/// Afterwards `handle` owns no workspace (null pointer, size 0), so its destructor will not
/// free the allocation now owned by this object.
Handle::Handle(Handle && handle) {

  // Every member is transferred. The original version left provider_ and last_operation_
  // uninitialized, and provider_ is read by every GEMM lookup key.
  provider_ = handle.provider_;
  device_ = handle.device_;
  workspace_size_ = handle.workspace_size_;
  workspace_ = handle.workspace_;
  stream_ = handle.stream_;
  scalar_pointer_mode_ = handle.scalar_pointer_mode_;
  last_operation_ = handle.last_operation_;

  // Relinquish ownership in the source so the workspace is freed exactly once.
  handle.workspace_ = nullptr;
  handle.workspace_size_ = 0;
}
|
|
|
|
|
|
|
|
/// Move assignment: frees this handle's current device workspace, then takes ownership of
/// `handle`'s workspace and copies its configuration (provider, device properties, stream,
/// scalar pointer mode and last selected operation).
///
/// Afterwards `handle` owns no workspace (null pointer, size 0). Self-assignment is a no-op.
Handle & Handle::operator=(Handle && handle) {

  // Without this guard, the source's workspace reset below would also clear our own pointer
  // and leak the allocation.
  if (this == &handle) {
    return *this;
  }

  // Release the allocation we currently own; overwriting workspace_ below would otherwise leak it.
  if (workspace_) {
    cudaFree(workspace_);
  }

  provider_ = handle.provider_;
  device_ = handle.device_;
  workspace_size_ = handle.workspace_size_;
  workspace_ = handle.workspace_;
  stream_ = handle.stream_;
  scalar_pointer_mode_ = handle.scalar_pointer_mode_;
  last_operation_ = handle.last_operation_;

  // Relinquish ownership in the source so the workspace is freed exactly once.
  handle.workspace_ = nullptr;
  handle.workspace_size_ = 0;

  return *this;
}
|
|
|
|
|
|
|
|
/// Returns the device's compute capability encoded as `10 * major + minor` (e.g. 80 for 8.0),
/// using the properties cached when the handle was constructed.
int Handle::compute_capability() const {
  return 10 * device_.major + device_.minor;
}
|
|
|
|
|
|
|
|
/// Replaces the CUDA stream on which subsequent operations are initialized and run.
void Handle::set_stream(cudaStream_t stream) {
  this->stream_ = stream;
}
|
|
|
|
|
|
|
|
/// Returns the CUDA stream currently associated with this handle.
cudaStream_t Handle::get_stream() const {
  return this->stream_;
}
|
|
|
|
|
2020-06-09 07:17:35 +08:00
|
|
|
/// Returns the provider used as part of the key when looking up operations.
Provider Handle::get_provider() const {
  return this->provider_;
}
|
|
|
|
|
|
|
|
/// Selects the provider whose operations subsequent GEMM calls will look up.
void Handle::set_provider(Provider provider) {
  this->provider_ = provider;
}
|
|
|
|
|
2020-04-08 04:51:25 +08:00
|
|
|
/// Returns the size, in bytes, of the device workspace owned by this handle.
size_t Handle::get_workspace_size() const {
  return this->workspace_size_;
}
|
|
|
|
|
|
|
|
/// Returns the device workspace allocation (obtained with cudaMalloc), or nullptr when the
/// workspace size is zero.
void *Handle::get_workspace() const {
  return this->workspace_;
}
|
|
|
|
|
|
|
|
/// Sets the size of the device workspace, in bytes, and clears it to zero.
///
/// If `bytes` differs from the current size, the previous allocation (if any) is released and,
/// when `bytes` is non-zero, a new one is made with cudaMalloc. Any pointer previously returned
/// by get_workspace() is invalidated in that case. Whether or not the size changed, an existing
/// workspace is zero-filled with cudaMemset.
///
/// Throws std::runtime_error if allocation or clearing fails. On failure the handle is left
/// with no workspace (nullptr, size 0). It then never advertises a size it does not actually
/// own, and nothing leaks if the exception escapes the constructor.
void Handle::set_workspace_size(size_t bytes) {

  // Returns the handle to the "no workspace" state.
  auto release_workspace = [this]() {
    if (workspace_) {
      cudaFree(workspace_);
    }
    workspace_ = nullptr;
    workspace_size_ = 0;
  };

  if (bytes != workspace_size_) {

    release_workspace();

    if (bytes) {
      void *allocation = nullptr;

      if (cudaMalloc(&allocation, bytes) != cudaSuccess) {
        // State is already empty; do not record `bytes` for a buffer we do not own.
        throw std::runtime_error("Failed to allocate workspace");
      }

      // Commit only after a successful allocation.
      workspace_ = allocation;
      workspace_size_ = bytes;
    }
  }

  if (workspace_) {
    if (cudaMemset(workspace_, 0, workspace_size_) != cudaSuccess) {
      release_workspace();
      throw std::runtime_error("Failed to clear workspace");
    }
  }
}
|
|
|
|
|
|
|
|
/// Returns the scalar pointer mode forwarded with alpha/beta in every GEMM call's arguments
/// (ScalarPointerMode::kHost unless changed).
ScalarPointerMode Handle::get_scalar_pointer_mode() const {
  return this->scalar_pointer_mode_;
}
|
|
|
|
|
|
|
|
/// Sets the scalar pointer mode forwarded with alpha/beta in subsequent GEMM calls.
void Handle::set_scalar_pointer_mode(ScalarPointerMode mode) {
  this->scalar_pointer_mode_ = mode;
}
|
|
|
|
|
|
|
|
/// Returns the operation chosen by the most recent GEMM kernel lookup that found a match, or
/// nullptr if none has. It is recorded before the operation is initialized, so it is updated
/// even if initialization or launch later fails.
Operation const *Handle::get_last_operation() const {
  return this->last_operation_;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Returns the strictest (largest) alignment requirement, in elements, among the A, B and C
/// operands of a GEMM kernel description.
static int maximum_alignment_requirement(GemmDescription const &desc) {
  auto const ab_alignment = std::max(desc.A.alignment, desc.B.alignment);
  return std::max(ab_alignment, desc.C.alignment);
}
|
|
|
|
|
|
|
|
/// Returns the largest alignment (in units of elements) the problem satisfies, starting from a
/// given upper limit.
///
/// Candidate byte alignments are tried from `max_alignment_in_bytes` downward, halving each time.
/// A candidate is accepted when:
///   - every pointer (A, B, C, D) is a multiple of it (a null pointer is 0 and always passes), and
///   - every extent (M, N, K, leading dimensions, batch strides) is a multiple of the element
///     alignment it implies. That is the largest `bytes * 8 / sizeof_bits(type)` over the A, B
///     and C element types.
///
/// Returns that element alignment for the first accepted candidate, or 0 if none is accepted.
/// 0 is also returned when the element types are wider than every remaining candidate. The
/// original code divided by zero in that case.
static int gemm_problem_alignment(
  int M,
  int N,
  int K,
  NumericTypeID element_A,
  void const *ptr_A,
  int lda,
  int64_t batch_stride_A,
  NumericTypeID element_B,
  void const *ptr_B,
  int ldb,
  int64_t batch_stride_B,
  NumericTypeID element_C,
  void const * ptr_C,
  int ldc,
  int64_t batch_stride_C,
  void const * ptr_D,
  int ldd,
  int64_t batch_stride_D,
  int max_alignment_in_bytes = 16
) {

  void const *pointers[] = {
    ptr_A, ptr_B, ptr_C, ptr_D
  };

  int64_t extents[] = {
    M, N, K, lda, ldb, ldc, ldd, batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D
  };

  NumericTypeID elements[] = {
    element_A, element_B, element_C
  };

  for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) {

    bool satisfied = true;

    // Can pointers satisfy this?
    for (void const *ptr : pointers) {
      std::uintptr_t int_ptr = reinterpret_cast<std::uintptr_t>(ptr);

      if (int_ptr % max_alignment_in_bytes) {
        satisfied = false;
        break;
      }
    }

    if (!satisfied) {
      continue;
    }

    // Compute the maximum alignment based on element data types
    int max_element_alignment = 0;

    for (NumericTypeID type_id : elements) {
      int bits = library::sizeof_bits(type_id);

      // An unknown type reports no size; skip it rather than divide by zero.
      if (bits <= 0) {
        continue;
      }

      int element_alignment = max_alignment_in_bytes * 8 / bits;
      max_element_alignment = std::max(max_element_alignment, element_alignment);
    }

    // No whole element fits in this byte alignment (e.g. 64-bit elements at a 4-byte candidate).
    // Smaller candidates can only yield 0 as well, and `extent % 0` below would be a division by
    // zero, so report that no alignment is satisfiable.
    if (max_element_alignment <= 0) {
      return 0;
    }

    // Can the problem size and leading dimensions satisfy this?
    for (int64_t extent : extents) {
      if (extent % max_element_alignment) {
        satisfied = false;
        break;
      }
    }

    if (!satisfied) {
      continue;
    }

    // Yes
    return max_element_alignment;
  }

  // No alignment satisfies this problem
  return 0;
}
|
|
|
|
|
|
|
|
/// Finds the best kernel for `preference_key` among the operations registered under one
/// functional key, or returns nullptr if none is compatible.
///
/// Preference buckets are visited from the greatest key not exceeding `preference_key` downward,
/// using the map's key ordering. NOTE(review): presumably compute capability first, then
/// alignment — confirm in GemmPreferenceKey. Within a bucket, operations are tried in stored
/// order. The first one is returned whose [minimum, maximum] compute capability range contains
/// the requested capability and whose strictest operand alignment does not exceed the requested
/// alignment.
static Operation const * find_gemm_operation(
  GemmOperationFunctionalMap::const_iterator operators_it,
  GemmPreferenceKey const preference_key) {

  auto const &by_preference = operators_it->second;
  auto bucket = by_preference.upper_bound(preference_key);

  while (bucket != by_preference.begin()) {
    --bucket;

    // Search tile sizes in order, for now.
    for (auto const *candidate : bucket->second) {
      GemmDescription const &desc =
        static_cast<GemmDescription const &>(candidate->description());

      bool const supports_cc =
        desc.tile_description.minimum_compute_capability <= preference_key.compute_capability &&
        preference_key.compute_capability <= desc.tile_description.maximum_compute_capability;

      bool const supports_alignment =
        maximum_alignment_requirement(desc) <= preference_key.alignment;

      if (supports_cc && supports_alignment) {
        return candidate;
      }
    }
  }

  return nullptr;
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Executes a GEMM computation: D <= alpha * A*B + beta * C
///
/// Looks up kernels registered under (provider, GemmKind::kGemm, element/layout/transform types).
/// It picks one compatible with this device's compute capability and with the largest alignment
/// the operands satisfy, then initializes and runs it on this handle's stream using the handle's
/// device workspace.
///
/// Returns Status::kErrorNotSupported if no compatible kernel exists or the kernel needs more
/// host or device workspace than is available. Otherwise returns the status reported by the
/// operation's initialize() or run().
Status Handle::gemm(

  int M,                                    /// GEMM M dimension
  int N,                                    /// GEMM N dimension
  int K,                                    /// GEMM K dimension

  NumericTypeID element_compute,            /// Data type of internal accumulation

  NumericTypeID element_scalar,             /// Data type of alpha/beta scalars

  void const *alpha,                        /// Pointer to alpha scalar

  NumericTypeID element_A,                  /// Data type of A matrix elements
  LayoutTypeID layout_A,                    /// Layout of A matrix
  ComplexTransform transform_A,             /// Complex transformation applied to A matrix - ignored for real-valued matrices

  void const * ptr_A,                       /// Pointer to A matrix in Global Memory
  int lda,                                  /// Leading dimension of A matrix

  NumericTypeID element_B,                  /// Data type of B matrix elements
  LayoutTypeID layout_B,                    /// Layout of B matrix
  ComplexTransform transform_B,             /// Complex transformation applied to B matrix - ignored for real-valued matrices

  void const * ptr_B,                       /// Pointer to B matrix in Global Memory
  int ldb,                                  /// Leading dimension of B matrix

  void const * beta,                        /// Pointer to beta scalar

  NumericTypeID element_C,                  /// Data type of C and D matrices

  void const * ptr_C,                       /// Pointer to C matrix
  int ldc,                                  /// Leading dimension of C matrix

  void * ptr_D,                             /// Pointer to D matrix
  int ldd                                   /// Leading dimension of D matrix
) {

  //
  // Find the operation
  //

  GemmFunctionalKey key(
    provider_,
    GemmKind::kGemm,
    element_compute,
    element_scalar,
    element_A,
    layout_A,
    transform_A,
    element_B,
    layout_B,
    transform_B,
    element_C
  );

  auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);

  if (operators_it == Singleton::get().operation_table.gemm_operations.end()) {
    return cutlass::Status::kErrorNotSupported;
  }

  if (operators_it->second.empty()) {
    return cutlass::Status::kErrorNotSupported;
  }

  //
  // Compute the largest alignment restriction the kernel can satisfy.
  //

  // Maximum alignment expectation among all kernels (in units of bytes)
  int const kMaximumAlignmentSize = 16;

  int alignment = gemm_problem_alignment(
    M, N, K,
    element_A, ptr_A, lda, 0,
    element_B, ptr_B, ldb, 0,
    element_C, ptr_C, ldc, 0,
    ptr_D, ldd, 0, kMaximumAlignmentSize
  );

  //
  // Find the best kernel in descending order of preference.
  //

  GemmPreferenceKey preference_key(compute_capability(), alignment);

  Operation const *operation = find_gemm_operation(operators_it, preference_key);

  if (!operation) {
    return cutlass::Status::kErrorNotSupported;
  }

  last_operation_ = operation;

  //
  // Configure operation
  //

  GemmConfiguration configuration{
    {M, N, K},
    lda,
    ldb,
    ldc,
    ldd,
    1
  };

  // Query host work space size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);

  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Host-side scratch handed to the operation's initialize()/run() as opaque storage for its
  // state. A plain char array is only guaranteed 1-byte alignment, which is not enough for
  // structured objects placed in it (NOTE(review): presumably via placement new in the
  // operation implementations; 16 bytes is assumed to cover their strictest member alignment).
  alignas(16) char host_workspace[kHostWorkspaceSize];

  // Query device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);

  if (uint64_t(workspace_size_) < device_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Initialize host and device workspaces
  Status status = operation->initialize(
    &configuration,
    host_workspace,
    workspace_,
    stream_);

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Run the operator
  GemmArguments arguments{
    ptr_A,
    ptr_B,
    ptr_C,
    ptr_D,
    alpha,
    beta,
    scalar_pointer_mode_
  };

  return operation->run(&arguments, host_workspace, workspace_, stream_);
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2020-06-09 07:17:35 +08:00
|
|
|
/// Executes a GEMM computation: D <= alpha * A*B + beta * C.
//
// Supports batched-strided, batched array or split-K serial or split-K parallel.
//
/// Looks up kernels registered under (provider, GemmKind::kUniversal, element/layout/transform
/// types). It picks one compatible with this device's compute capability and with the largest
/// alignment the problem satisfies, then initializes and runs it on this handle's stream using
/// the handle's device workspace.
///
/// In GemmUniversalMode::kArray mode the pointer arguments refer to per-batch pointer arrays in
/// device memory, so pointer alignment is not checked on the host. In kBatched mode the batch
/// strides are included in the alignment check.
///
/// Returns Status::kErrorNotSupported if no compatible kernel exists or the kernel needs more
/// host or device workspace than is available. Otherwise returns the status reported by the
/// operation's initialize() or run().
Status Handle::gemm_universal(

  GemmUniversalMode mode,                   /// indicates the mode in which the kUniversal GEMM is launched

  int M,                                    /// GEMM M dimension
  int N,                                    /// GEMM N dimension
  int K,                                    /// GEMM K dimension

  NumericTypeID element_compute,            /// Data type of internal accumulation

  NumericTypeID element_scalar,             /// Data type of alpha/beta scalars

  void const *alpha,                        /// Pointer to alpha scalar

  NumericTypeID element_A,                  /// Data type of A matrix elements
  LayoutTypeID layout_A,                    /// Layout of A matrix
  ComplexTransform transform_A,             /// Complex transformation applied to A matrix - ignored for real-valued matrices

  void const * ptr_A,                       /// Pointer to A matrix in Global Memory
  int lda,                                  /// Leading dimension of A matrix

  NumericTypeID element_B,                  /// Data type of B matrix elements
  LayoutTypeID layout_B,                    /// Layout of B matrix
  ComplexTransform transform_B,             /// Complex transformation applied to B matrix - ignored for real-valued matrices

  void const * ptr_B,                       /// Pointer to B matrix in Global Memory
  int ldb,                                  /// Leading dimension of B matrix

  void const * beta,                        /// Pointer to beta scalar

  NumericTypeID element_C,                  /// Data type of C and D matrices

  void const * ptr_C,                       /// Pointer to C matrix
  int ldc,                                  /// Leading dimension of C matrix

  void * ptr_D,                             /// Pointer to D matrix
  int ldd,                                  /// Leading dimension of D matrix

  int batch_count,                          /// Batch count or number of split-K slices

  int64_t batch_stride_A,                   /// Batch stride of A operand
  int64_t batch_stride_B,                   /// Batch stride of B operand
  int64_t batch_stride_C,                   /// Batch stride of C operand
  int64_t batch_stride_D                    /// Batch stride of D operand
) {

  //
  // Find the operation
  //

  GemmFunctionalKey key(
    provider_,
    GemmKind::kUniversal,
    element_compute,
    element_scalar,
    element_A,
    layout_A,
    transform_A,
    element_B,
    layout_B,
    transform_B,
    element_C
  );

  auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);

  if (operators_it == Singleton::get().operation_table.gemm_operations.end()) {
    return cutlass::Status::kErrorNotSupported;
  }

  if (operators_it->second.empty()) {
    return cutlass::Status::kErrorNotSupported;
  }

  //
  // Compute the largest alignment restriction the kernel can satisfy.
  //

  // Maximum alignment expectation among all kernels (in units of bytes)
  int const kMaximumAlignmentSize = 16;

  void const *ptr_A_check = ptr_A;
  void const *ptr_B_check = ptr_B;
  void const *ptr_C_check = ptr_C;
  void *      ptr_D_check = ptr_D;

  // Ignore alignment of pointers to pointers. We can't check this from the host,
  // as each batch index has its own pointer in device memory.
  if (mode == GemmUniversalMode::kArray) {
    ptr_A_check = nullptr;
    ptr_B_check = nullptr;
    ptr_C_check = nullptr;
    ptr_D_check = nullptr;
  }

  // In batched mode, batch i of each operand begins `i * batch_stride` elements past its base
  // pointer, so the strides must be multiples of the chosen alignment too. Other modes leave the
  // strides out of the check, as before. NOTE(review): confirm whether kGemmSplitKParallel
  // offsets D by batch_stride_D and should be included as well.
  bool const strided_batches = (mode == GemmUniversalMode::kBatched);

  int64_t const batch_stride_A_check = strided_batches ? batch_stride_A : 0;
  int64_t const batch_stride_B_check = strided_batches ? batch_stride_B : 0;
  int64_t const batch_stride_C_check = strided_batches ? batch_stride_C : 0;
  int64_t const batch_stride_D_check = strided_batches ? batch_stride_D : 0;

  int alignment = gemm_problem_alignment(
    M, N, K,
    element_A, ptr_A_check, lda, batch_stride_A_check,
    element_B, ptr_B_check, ldb, batch_stride_B_check,
    element_C, ptr_C_check, ldc, batch_stride_C_check,
    ptr_D_check, ldd, batch_stride_D_check, kMaximumAlignmentSize
  );

  //
  // Find the best kernel in descending order of preference.
  //

  GemmPreferenceKey preference_key(compute_capability(), alignment);

  Operation const *operation = find_gemm_operation(operators_it, preference_key);

  if (!operation) {
    return cutlass::Status::kErrorNotSupported;
  }

  last_operation_ = operation;

  //
  // Configure operation
  //

  GemmUniversalConfiguration configuration{
    mode,
    {M, N, K},
    batch_count,
    lda,
    ldb,
    ldc,
    ldd
  };

  // Query host work space size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);

  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Host-side scratch handed to the operation's initialize()/run() as opaque storage for its
  // state. A plain char array is only guaranteed 1-byte alignment, which is not enough for
  // structured objects placed in it (NOTE(review): presumably via placement new in the
  // operation implementations; 16 bytes is assumed to cover their strictest member alignment).
  alignas(16) char host_workspace[kHostWorkspaceSize];

  // Query device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);

  if (uint64_t(workspace_size_) < device_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Initialize host and device workspaces
  Status status = operation->initialize(
    &configuration,
    host_workspace,
    workspace_,
    stream_);

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Run the operator
  GemmUniversalArguments arguments{
    ptr_A,
    ptr_B,
    ptr_C,
    ptr_D,
    alpha,
    beta,
    scalar_pointer_mode_,
    batch_stride_A,
    batch_stride_B,
    batch_stride_C,
    batch_stride_D
  };

  return operation->run(&arguments, host_workspace, workspace_, stream_);
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2020-04-08 04:51:25 +08:00
|
|
|
/// Planar complex GEMM
///
/// Computes a batched GEMM whose complex operands are stored as separate real and imaginary
/// planes. Looks up kernels registered under (provider, GemmKind::kPlanarComplex,
/// element/layout/transform types). It picks one compatible with this device's compute
/// capability and with an alignment satisfied by both the real and the imaginary planes, then
/// launches it in GemmUniversalMode::kBatched mode on this handle's stream.
///
/// Returns Status::kErrorNotSupported if no compatible kernel exists or the kernel needs more
/// host or device workspace than is available. Otherwise returns the status reported by the
/// operation's initialize() or run().
Status Handle::gemm_planar_complex(

  int M,                                    /// GEMM M dimension
  int N,                                    /// GEMM N dimension
  int K,                                    /// GEMM K dimension

  NumericTypeID element_compute,            /// Data type of internal accumulation

  NumericTypeID element_scalar,             /// Data type of alpha/beta scalars

  void const *alpha,                        /// Pointer to alpha scalar

  NumericTypeID element_A,                  /// Data type of A matrix elements
  LayoutTypeID layout_A,                    /// Layout of A matrix
  ComplexTransform transform_A,             /// Complex transformation applied to A matrix

  void const * ptr_A_real,                  /// Pointer to real part of A matrix
  void const * ptr_A_imag,                  /// Pointer to imaginary part of A matrix
  int lda_real,                             /// Leading dimension of real part of A matrix
  int lda_imag,                             /// Leading dimension of imaginary part of A matrix

  NumericTypeID element_B,                  /// Data type of B matrix elements
  LayoutTypeID layout_B,                    /// Layout of B matrix
  ComplexTransform transform_B,             /// Complex transformation applied to B matrix

  void const * ptr_B_real,                  /// Pointer to real part of B matrix
  void const * ptr_B_imag,                  /// Pointer to imaginary part of B matrix
  int ldb_real,                             /// Leading dimension of real part of B matrix
  int ldb_imag,                             /// Leading dimension of imaginary part of B matrix

  void const * beta,                        /// Pointer to beta scalar

  NumericTypeID element_C,                  /// Data type of C and D matrix

  void const * ptr_C_real,                  /// Pointer to real part of C matrix
  void const * ptr_C_imag,                  /// Pointer to imaginary part of C matrix
  int ldc_real,                             /// Leading dimension of real part of C matrix
  int ldc_imag,                             /// Leading dimension of imaginary part of C matrix

  void * ptr_D_real,                        /// Pointer to real part of D matrix
  void * ptr_D_imag,                        /// Pointer to imaginary part of D matrix
  int ldd_real,                             /// Leading dimension of real part of D matrix
  int ldd_imag,                             /// Leading dimension of imaginary part of D matrix

  int batch_count,                          /// Number of batched GEMMs to execute

  int64_t batch_stride_A_real,
  int64_t batch_stride_A_imag,

  int64_t batch_stride_B_real,
  int64_t batch_stride_B_imag,

  int64_t batch_stride_C_real,
  int64_t batch_stride_C_imag,

  int64_t batch_stride_D_real,
  int64_t batch_stride_D_imag
) {

  //
  // Find the operation
  //

  GemmFunctionalKey key(
    provider_,
    GemmKind::kPlanarComplex,
    element_compute,
    element_scalar,
    element_A,
    layout_A,
    transform_A,
    element_B,
    layout_B,
    transform_B,
    element_C
  );

  auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);

  if (operators_it == Singleton::get().operation_table.gemm_operations.end()) {
    return cutlass::Status::kErrorNotSupported;
  }

  if (operators_it->second.empty()) {
    return cutlass::Status::kErrorNotSupported;
  }

  //
  // Compute the largest alignment restriction the kernel can satisfy.
  //

  // Maximum alignment expectation among all kernels (in units of bytes)
  int const kMaximumAlignmentSize = 16;

  // The selected kernel applies a single alignment requirement to both the real and the
  // imaginary planes, so use the largest alignment that BOTH satisfy, i.e. the minimum of the
  // two. Taking the maximum (as before) could select a kernel whose accesses are misaligned for
  // one of the planes. If either plane supports no alignment (0), no kernel is selected.
  int alignment = std::min(
    gemm_problem_alignment(
      M, N, K,
      element_A, ptr_A_real, lda_real, batch_stride_A_real,
      element_B, ptr_B_real, ldb_real, batch_stride_B_real,
      element_C, ptr_C_real, ldc_real, batch_stride_C_real,
      ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize
    ),
    gemm_problem_alignment(
      M, N, K,
      element_A, ptr_A_imag, lda_imag, batch_stride_A_imag,
      element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag,
      element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag,
      ptr_D_imag, ldd_imag, batch_stride_D_imag, kMaximumAlignmentSize
    )
  );

  //
  // Find the best kernel in descending order of preference.
  //

  GemmPreferenceKey preference_key(compute_capability(), alignment);

  Operation const *operation = find_gemm_operation(operators_it, preference_key);

  if (!operation) {
    return cutlass::Status::kErrorNotSupported;
  }

  last_operation_ = operation;

  //
  // Configure operation
  //

  GemmPlanarComplexConfiguration configuration{
    GemmUniversalMode::kBatched,
    {M, N, K},
    batch_count,
    lda_real,
    lda_imag,
    ldb_real,
    ldb_imag,
    ldc_real,
    ldc_imag,
    ldd_real,
    ldd_imag
  };

  // Query host work space size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);

  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Host-side scratch handed to the operation's initialize()/run() as opaque storage for its
  // state. A plain char array is only guaranteed 1-byte alignment, which is not enough for
  // structured objects placed in it (NOTE(review): presumably via placement new in the
  // operation implementations; 16 bytes is assumed to cover their strictest member alignment).
  alignas(16) char host_workspace[kHostWorkspaceSize];

  // Query device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);

  if (uint64_t(workspace_size_) < device_workspace_size_needed) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Initialize host and device workspaces
  Status status = operation->initialize(
    &configuration,
    host_workspace,
    workspace_,
    stream_);

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Run the operator
  GemmPlanarComplexArguments arguments{
    ptr_A_real,
    ptr_A_imag,
    ptr_B_real,
    ptr_B_imag,
    ptr_C_real,
    ptr_C_imag,
    ptr_D_real,
    ptr_D_imag,
    alpha,
    beta,
    scalar_pointer_mode_,
    batch_stride_A_real,
    batch_stride_A_imag,
    batch_stride_B_real,
    batch_stride_B_imag,
    batch_stride_C_real,
    batch_stride_C_imag,
    batch_stride_D_real,
    batch_stride_D_imag
  };

  return operation->run(&arguments, host_workspace, workspace_, stream_);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Planar complex batched GEMM loading pointers from arrays in global memory
|
|
|
|
Status Handle::gemm_planar_complex_array(
|
|
|
|
|
|
|
|
int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid)
|
|
|
|
int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid)
|
|
|
|
int expected_K, /// Expected GEMM K dimension
|
|
|
|
int batch_count, /// Number of independent GEMM computations to execute
|
|
|
|
|
|
|
|
int const *M, /// Array containing the GEMM M dimension for each batch index
|
|
|
|
int const *N, /// Array containing the GEMM N dimension for each batch index
|
|
|
|
int const *K, /// Array containing the GEMM K dimension for each batch index
|
|
|
|
|
|
|
|
NumericTypeID element_compute, /// Data type of internal accumulation
|
|
|
|
|
|
|
|
NumericTypeID element_scalar, /// Data type of alpha/beta scalars
|
|
|
|
|
|
|
|
void const *alpha, /// Pointer to alpha scalar
|
|
|
|
|
|
|
|
NumericTypeID element_A, /// Data type of A matrix elements
|
|
|
|
LayoutTypeID layout_A, /// Layout of A matrix
|
|
|
|
ComplexTransform transform_A, /// Complex transformation applied to A matrix
|
|
|
|
|
|
|
|
void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices
|
|
|
|
void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices
|
|
|
|
|
|
|
|
int lda_real, /// Leading dimension of real part of A matrix
|
|
|
|
int lda_imag, /// Leading dimension of imaginary part of A matrix
|
|
|
|
|
|
|
|
NumericTypeID element_B, /// Data type of B matrix elements
|
|
|
|
LayoutTypeID layout_B, /// Layout of B matrix
|
|
|
|
ComplexTransform transform_B, /// Complex transformation applied to B matrix
|
|
|
|
|
|
|
|
void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices
|
|
|
|
void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices
|
|
|
|
|
|
|
|
int ldb_real, /// Leading dimension of real part of B matrix
|
|
|
|
int ldb_imag, /// Leading dimension of imaginary part of B matrix
|
|
|
|
|
|
|
|
void const * beta, /// Pointer to beta scalar
|
|
|
|
|
|
|
|
NumericTypeID element_C, /// Data type of C and D matrix
|
|
|
|
|
|
|
|
void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices
|
|
|
|
void const * const * ptr_C_imag, /// Pointer to array containing poitners to imaginary part of C matrices
|
|
|
|
|
|
|
|
int ldc_real, /// Leading dimension of real part of C matrix
|
|
|
|
int ldc_imag, /// Leading dimension of imaginary part of C matrix
|
|
|
|
|
|
|
|
void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices
|
|
|
|
void * const * ptr_D_imag, /// Pointer to array containing poitners to imaginary part of D matrices
|
|
|
|
|
|
|
|
int ldd_real, /// Leading dimension of real part of D matrix
|
|
|
|
int ldd_imag /// Leading dimension of imaginary part of D matrix
|
|
|
|
) {
|
|
|
|
|
|
|
|
//
|
|
|
|
// Find the operation
|
|
|
|
//
|
|
|
|
|
|
|
|
GemmFunctionalKey key(
|
2020-06-09 07:17:35 +08:00
|
|
|
provider_,
|
|
|
|
GemmKind::kPlanarComplexArray,
|
2020-04-08 04:51:25 +08:00
|
|
|
element_compute,
|
|
|
|
element_scalar,
|
|
|
|
element_A,
|
|
|
|
layout_A,
|
|
|
|
transform_A,
|
|
|
|
element_B,
|
|
|
|
layout_B,
|
|
|
|
transform_B,
|
|
|
|
element_C
|
|
|
|
);
|
|
|
|
|
2020-06-09 07:17:35 +08:00
|
|
|
auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
|
2020-04-08 04:51:25 +08:00
|
|
|
|
2020-06-09 07:17:35 +08:00
|
|
|
if (operators_it == Singleton::get().operation_table.gemm_operations.end()) {
|
2020-04-08 04:51:25 +08:00
|
|
|
return cutlass::Status::kErrorNotSupported;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (operators_it->second.empty()) {
|
|
|
|
return cutlass::Status::kErrorNotSupported;
|
|
|
|
}
|
|
|
|
|
|
|
|
//
|
|
|
|
// Compute the largest alignment restriction the kernel can satisfy.
|
|
|
|
//
|
|
|
|
|
|
|
|
// Maximum alignment expectation among all kernels (in units of bytes)
|
|
|
|
int const kMaximumAlignmentSize = 16;
|
|
|
|
|
|
|
|
int alignment = std::max(
|
|
|
|
gemm_problem_alignment(
|
|
|
|
expected_M, expected_N, expected_K,
|
|
|
|
element_A, nullptr, lda_real, 0,
|
|
|
|
element_B, nullptr, ldb_real, 0,
|
|
|
|
element_C, nullptr, ldc_real, 0,
|
|
|
|
nullptr, ldd_real, 0, kMaximumAlignmentSize
|
|
|
|
),
|
|
|
|
gemm_problem_alignment(
|
|
|
|
expected_M, expected_N, expected_K,
|
|
|
|
element_A, nullptr, lda_imag, 0,
|
|
|
|
element_B, nullptr, ldb_imag, 0,
|
|
|
|
element_C, nullptr, ldc_imag, 0,
|
|
|
|
nullptr, ldd_imag, 0, kMaximumAlignmentSize
|
|
|
|
)
|
|
|
|
);
|
|
|
|
|
|
|
|
//
|
|
|
|
// Find the best kernel in descending order of preference.
|
|
|
|
//
|
|
|
|
|
|
|
|
GemmPreferenceKey preference_key(compute_capability(), alignment);
|
|
|
|
|
|
|
|
Operation const *operation = find_gemm_operation(operators_it, preference_key);
|
|
|
|
|
|
|
|
if (!operation) {
|
|
|
|
return cutlass::Status::kErrorNotSupported;
|
|
|
|
}
|
|
|
|
|
|
|
|
last_operation_ = operation;
|
|
|
|
|
|
|
|
//
|
|
|
|
// Configure operation
|
|
|
|
//
|
|
|
|
|
|
|
|
GemmPlanarComplexArrayConfiguration configuration{
|
|
|
|
{expected_M, expected_N, expected_K},
|
|
|
|
batch_count,
|
|
|
|
lda_real,
|
|
|
|
lda_imag,
|
|
|
|
ldb_real,
|
|
|
|
ldb_imag,
|
|
|
|
ldc_real,
|
|
|
|
ldc_imag,
|
|
|
|
ldd_real,
|
|
|
|
ldd_imag
|
|
|
|
};
|
|
|
|
|
|
|
|
// Query host work space size
|
|
|
|
uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
|
|
|
|
|
|
|
|
if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) {
|
|
|
|
return cutlass::Status::kErrorNotSupported;
|
|
|
|
}
|
|
|
|
|
|
|
|
char host_workspace[kHostWorkspaceSize];
|
|
|
|
|
|
|
|
// Query device workspace size
|
|
|
|
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
|
|
|
|
|
|
|
|
if (uint64_t(workspace_size_) < device_workspace_size_needed) {
|
|
|
|
return cutlass::Status::kErrorNotSupported;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Initialize host and device workspaces
|
|
|
|
Status status = operation->initialize(
|
|
|
|
&configuration,
|
|
|
|
host_workspace,
|
|
|
|
workspace_,
|
|
|
|
stream_);
|
|
|
|
|
|
|
|
if (status != cutlass::Status::kSuccess) {
|
|
|
|
return status;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Run the operator
|
|
|
|
GemmPlanarComplexArrayArguments arguments{
|
|
|
|
M, N, K,
|
|
|
|
ptr_A_real,
|
|
|
|
ptr_A_imag,
|
|
|
|
ptr_B_real,
|
|
|
|
ptr_B_imag,
|
|
|
|
ptr_C_real,
|
|
|
|
ptr_C_imag,
|
|
|
|
ptr_D_real,
|
|
|
|
ptr_D_imag,
|
|
|
|
alpha,
|
|
|
|
beta,
|
|
|
|
scalar_pointer_mode_
|
|
|
|
};
|
|
|
|
|
|
|
|
return operation->run(&arguments, host_workspace, workspace_, stream_);
|
|
|
|
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
2020-11-20 13:25:25 +08:00
|
|
|
|
|
|
|
/// Finds conv operation instances with Conv::ElementC = Reduction::ElementWorkspace
|
|
|
|
Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation) {
|
|
|
|
|
|
|
|
ConvDescription const &conv_desc =
|
|
|
|
static_cast<ConvDescription const &>(operation->description());
|
|
|
|
|
|
|
|
// if the curren conv operation accumulator and output data type match return operation
|
|
|
|
if(conv_desc.tile_description.math_instruction.element_accumulator == conv_desc.C.element) {
|
|
|
|
return operation;
|
|
|
|
}
|
|
|
|
|
|
|
|
// find conv operation to match conv output and reduction workspace data type
|
|
|
|
ConvFunctionalKey key(
|
|
|
|
library::Provider::kCUTLASS,
|
|
|
|
conv_desc.conv_kind,
|
|
|
|
conv_desc.A.element,
|
|
|
|
conv_desc.A.layout,
|
|
|
|
conv_desc.B.element,
|
|
|
|
conv_desc.B.layout,
|
|
|
|
conv_desc.tile_description.math_instruction.element_accumulator,
|
|
|
|
conv_desc.C.layout,
|
|
|
|
conv_desc.tile_description.math_instruction.element_accumulator,
|
|
|
|
conv_desc.element_epilogue);
|
|
|
|
|
|
|
|
// conv operation table for conv2d or conv3d
|
|
|
|
auto conv_operations = (conv_desc.kind == OperationKind::kConv2d) ?
|
|
|
|
Singleton::get().operation_table.conv2d_operations :
|
|
|
|
Singleton::get().operation_table.conv3d_operations;
|
|
|
|
|
|
|
|
// find ConvFunctionalKey in convolution operation table
|
|
|
|
auto operators_it = conv_operations.find(key);
|
|
|
|
|
|
|
|
if (operators_it == conv_operations.end()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (operators_it->second.empty()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
// conv operation for same compute capability and iterator algorithm
|
|
|
|
ConvPreferenceKey preference_key(
|
|
|
|
conv_desc.tile_description.minimum_compute_capability,
|
|
|
|
conv_desc.iterator_algorithm);
|
|
|
|
|
|
|
|
auto it = operators_it->second.find(preference_key);
|
|
|
|
|
|
|
|
if(it == operators_it->second.end()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
// return matching conv opertion (same tile sizes and instruction)
|
|
|
|
for (auto op : it->second) {
|
|
|
|
if (op->description().tile_description == operation->description().tile_description) {
|
|
|
|
return op;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2020-04-08 04:51:25 +08:00
|
|
|
} // namespace library
|
|
|
|
} // namespace cutlass
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|