1600 lines
43 KiB
Plaintext
1600 lines
43 KiB
Plaintext
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
|
|
#include <iosfwd>
|
|
#include <complex>
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/numeric_types.h"
|
|
#include "cutlass/complex.h"
|
|
#include "cutlass/blas3.h"
|
|
|
|
#include "cutlass/layout/matrix.h"
|
|
|
|
#include "cutlass/library/library.h"
|
|
#include "cutlass/library/util.h"
|
|
|
|
namespace cutlass {
|
|
namespace library {
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
Provider enumerant;
|
|
}
|
|
Provider_enumerants[] = {
|
|
{"none", "None", Provider::kNone},
|
|
{"cutlass", "CUTLASS", Provider::kCUTLASS},
|
|
{"host", "reference_host", Provider::kReferenceHost},
|
|
{"device", "reference_device", Provider::kReferenceDevice},
|
|
{"cublas", "cuBLAS", Provider::kCUBLAS},
|
|
{"cudnn", "cuDNN", Provider::kCUDNN},
|
|
};
|
|
|
|
/// Converts a Provider enumerant to a string
|
|
char const *to_string(Provider provider, bool pretty) {
|
|
|
|
for (auto const & possible : Provider_enumerants) {
|
|
if (provider == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Parses a Provider enumerant from a string
|
|
template <>
|
|
Provider from_string<Provider>(std::string const &str) {
|
|
|
|
for (auto const & possible : Provider_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return Provider::kInvalid;
|
|
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
GemmKind enumerant;
|
|
}
|
|
GemmKind_enumerants[] = {
|
|
{"gemm", "<Gemm>", GemmKind::kGemm},
|
|
{"spgemm", "<Sparse>", GemmKind::kSparse},
|
|
{"universal", "<Universal>", GemmKind::kUniversal},
|
|
{"planar_complex", "<PlanarComplex>", GemmKind::kPlanarComplex},
|
|
{"planar_complex_array", "<PlanarComplexArray>", GemmKind::kPlanarComplexArray},
|
|
{"grouped", "<Grouped>", GemmKind::kGrouped},
|
|
};
|
|
|
|
/// Converts a GemmKind enumerant to a string
|
|
char const *to_string(GemmKind type, bool pretty) {
|
|
|
|
for (auto const & possible : GemmKind_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
RankKKind enumerant;
|
|
}
|
|
RankKKind_enumerants[] = {
|
|
{"universal", "<Universal>", RankKKind::kUniversal},
|
|
};
|
|
|
|
/// Converts a SyrkKind enumerant to a string
|
|
char const *to_string(RankKKind type, bool pretty) {
|
|
|
|
for (auto const & possible :RankKKind_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
TrmmKind enumerant;
|
|
}
|
|
TrmmKind_enumerants[] = {
|
|
{"universal", "<Universal>", TrmmKind::kUniversal},
|
|
};
|
|
|
|
/// Converts a TrmmKind enumerant to a string
|
|
char const *to_string(TrmmKind type, bool pretty) {
|
|
|
|
for (auto const & possible :TrmmKind_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
SymmKind enumerant;
|
|
}
|
|
SymmKind_enumerants[] = {
|
|
{"universal", "<Universal>", SymmKind::kUniversal},
|
|
};
|
|
|
|
/// Converts a SymmKind enumerant to a string
|
|
char const *to_string(SymmKind type, bool pretty) {
|
|
|
|
for (auto const & possible :SymmKind_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
SideMode enumerant;
|
|
}
|
|
SideMode_enumerants[] = {
|
|
{"left", "Left", SideMode::kLeft},
|
|
{"right", "Right", SideMode::kRight}
|
|
};
|
|
|
|
/// Converts a SideMode enumerant to a string
|
|
char const *to_string(SideMode type, bool pretty) {
|
|
|
|
for (auto const & possible :SideMode_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
FillMode enumerant;
|
|
}
|
|
FillMode_enumerants[] = {
|
|
{"lower", "Lower", FillMode::kLower},
|
|
{"upper", "Upper", FillMode::kUpper}
|
|
};
|
|
|
|
/// Converts a FillMode enumerant to a string
|
|
char const *to_string(FillMode type, bool pretty) {
|
|
|
|
for (auto const & possible :FillMode_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
BlasMode enumerant;
|
|
}
|
|
BlasMode_enumerants[] = {
|
|
{"symmetric", "Symmetric", BlasMode::kSymmetric},
|
|
{"hermitian", "Hermitian", BlasMode::kHermitian}
|
|
};
|
|
|
|
/// Converts a BlasMode enumerant to a string
|
|
char const *to_string(BlasMode type, bool pretty) {
|
|
|
|
for (auto const & possible :BlasMode_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
DiagType enumerant;
|
|
}
|
|
DiagType_enumerants[] = {
|
|
{"nonunit", "NonUnit", DiagType::kNonUnit},
|
|
{"unit", "Unit", DiagType::kUnit}
|
|
};
|
|
|
|
/// Converts a DiagType enumerant to a string
|
|
char const *to_string(DiagType type, bool pretty) {
|
|
|
|
for (auto const & possible :DiagType_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
OperationKind enumerant;
|
|
}
|
|
OperationKind_enumerants[] = {
|
|
{"eq_gemm", "EqGemm", OperationKind::kEqGemm},
|
|
{"gemm", "Gemm", OperationKind::kGemm},
|
|
{"rank_k", "RankK", OperationKind::kRankK},
|
|
{"rank_2k", "Rank2K", OperationKind::kRank2K},
|
|
{"trmm", "Trmm", OperationKind::kTrmm},
|
|
{"symm", "Symm", OperationKind::kSymm},
|
|
{"conv2d", "Conv2d", OperationKind::kConv2d},
|
|
{"conv3d", "Conv3d", OperationKind::kConv3d},
|
|
{"spgemm", "SparseGemm", OperationKind::kSparseGemm},
|
|
};
|
|
|
|
/// Converts a Status enumerant to a string
|
|
char const *to_string(OperationKind enumerant, bool pretty) {
|
|
|
|
for (auto const & possible : OperationKind_enumerants) {
|
|
if (enumerant == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Converts a Status enumerant from a string
|
|
template <>
|
|
OperationKind from_string<OperationKind>(std::string const &str) {
|
|
|
|
for (auto const & possible : OperationKind_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return OperationKind::kInvalid;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
Status enumerant;
|
|
}
|
|
Status_enumerants[] = {
|
|
{"success", "Success", Status::kSuccess},
|
|
{"misaligned_operand", "Error: misaligned operand", Status::kErrorMisalignedOperand},
|
|
{"invalid_problem", "Error: invalid problem", Status::kErrorInvalidProblem},
|
|
{"not_supported", "Error: not supported", Status::kErrorNotSupported},
|
|
{"internal", "Error: internal", Status::kErrorInternal}
|
|
};
|
|
|
|
/// Converts a Status enumerant to a string
|
|
char const *to_string(Status status, bool pretty) {
|
|
|
|
for (auto const & possible : Status_enumerants) {
|
|
if (status == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Converts a Status enumerant from a string
|
|
template <>
|
|
Status from_string<Status>(std::string const &str) {
|
|
|
|
for (auto const & possible : Status_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return Status::kInvalid;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
NumericTypeID enumerant;
|
|
}
|
|
NumericTypeID_enumerants[] = {
|
|
{"unknown", "<unkown>", NumericTypeID::kUnknown},
|
|
{"void", "Void", NumericTypeID::kVoid},
|
|
{"b1", "B1", NumericTypeID::kB1},
|
|
{"u2", "U2", NumericTypeID::kU2},
|
|
{"u4", "U4", NumericTypeID::kU4},
|
|
{"u8", "U8", NumericTypeID::kU8},
|
|
{"u16", "U16", NumericTypeID::kU16},
|
|
{"u32", "U32", NumericTypeID::kU32},
|
|
{"u64", "U64", NumericTypeID::kU64},
|
|
{"s2", "S2", NumericTypeID::kS2},
|
|
{"s4", "S4", NumericTypeID::kS4},
|
|
{"s8", "S8", NumericTypeID::kS8},
|
|
{"s16", "S16", NumericTypeID::kS16},
|
|
{"s32", "S32", NumericTypeID::kS32},
|
|
{"s64", "S64", NumericTypeID::kS64},
|
|
{"f16", "F16", NumericTypeID::kF16},
|
|
{"bf16", "BF16", NumericTypeID::kBF16},
|
|
{"f32", "F32", NumericTypeID::kF32},
|
|
{"tf32", "TF32", NumericTypeID::kTF32},
|
|
{"f64", "F64", NumericTypeID::kF64},
|
|
{"cf16", "CF16", NumericTypeID::kCF16},
|
|
{"cbf16", "CBF16", NumericTypeID::kCBF16},
|
|
{"cf32", "CF32", NumericTypeID::kCF32},
|
|
{"ctf32", "CTF32", NumericTypeID::kCTF32},
|
|
{"cf64", "CF64", NumericTypeID::kCF64},
|
|
{"cu2", "CU2", NumericTypeID::kCU2},
|
|
{"cu4", "CU4", NumericTypeID::kCU4},
|
|
{"cu8", "CU8", NumericTypeID::kCU8},
|
|
{"cu16", "CU16", NumericTypeID::kCU16},
|
|
{"cu32", "CU32", NumericTypeID::kCU32},
|
|
{"cu64", "CU64", NumericTypeID::kCU64},
|
|
{"cs2", "CS2", NumericTypeID::kCS2},
|
|
{"cs4", "CS4", NumericTypeID::kCS4},
|
|
{"cs8", "CS8", NumericTypeID::kCS8},
|
|
{"cs16", "CS16", NumericTypeID::kCS16},
|
|
{"cs32", "CS32", NumericTypeID::kCS32},
|
|
{"cs64", "CS64", NumericTypeID::kCS64},
|
|
{"*", "<unkown/enumerate all>", NumericTypeID::kUnknown}
|
|
};
|
|
|
|
/// Converts a NumericTypeID enumerant to a string
|
|
char const *to_string(NumericTypeID type, bool pretty) {
|
|
|
|
for (auto const & possible : NumericTypeID_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Parses a NumericTypeID enumerant from a string
|
|
template <>
|
|
NumericTypeID from_string<NumericTypeID>(std::string const &str) {
|
|
|
|
for (auto const & possible : NumericTypeID_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return NumericTypeID::kInvalid;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Returns the size of a data type in bits
|
|
int sizeof_bits(NumericTypeID type) {
|
|
switch (type) {
|
|
case NumericTypeID::kF16: return 16;
|
|
case NumericTypeID::kBF16: return 16;
|
|
case NumericTypeID::kTF32: return 32;
|
|
case NumericTypeID::kF32: return 32;
|
|
case NumericTypeID::kF64: return 64;
|
|
case NumericTypeID::kCF16: return 32;
|
|
case NumericTypeID::kCBF16: return 32;
|
|
case NumericTypeID::kCF32: return 64;
|
|
case NumericTypeID::kCTF32: return 64;
|
|
case NumericTypeID::kCF64: return 128;
|
|
case NumericTypeID::kS2: return 2;
|
|
case NumericTypeID::kS4: return 4;
|
|
case NumericTypeID::kS8: return 8;
|
|
case NumericTypeID::kS16: return 16;
|
|
case NumericTypeID::kS32: return 32;
|
|
case NumericTypeID::kS64: return 64;
|
|
case NumericTypeID::kU2: return 2;
|
|
case NumericTypeID::kU4: return 4;
|
|
case NumericTypeID::kU8: return 8;
|
|
case NumericTypeID::kU16: return 16;
|
|
case NumericTypeID::kU32: return 32;
|
|
case NumericTypeID::kU64: return 64;
|
|
case NumericTypeID::kB1: return 1;
|
|
default: break;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
/// Returns true if the numeric type is a complex data type or false if real-valued.
|
|
bool is_complex_type(NumericTypeID type) {
|
|
switch (type) {
|
|
case NumericTypeID::kCF16: return true;
|
|
case NumericTypeID::kCF32: return true;
|
|
case NumericTypeID::kCF64: return true;
|
|
case NumericTypeID::kCBF16: return true;
|
|
case NumericTypeID::kCTF32: return true;
|
|
default: break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Returns the field underlying a complex valued type
|
|
NumericTypeID get_real_type(NumericTypeID type) {
|
|
switch (type) {
|
|
case NumericTypeID::kCF16: return NumericTypeID::kF16;
|
|
case NumericTypeID::kCF32: return NumericTypeID::kF32;
|
|
case NumericTypeID::kCF64: return NumericTypeID::kF64;
|
|
case NumericTypeID::kCBF16: return NumericTypeID::kBF16;
|
|
case NumericTypeID::kCTF32: return NumericTypeID::kTF32;
|
|
default: break;
|
|
}
|
|
return type;
|
|
}
|
|
|
|
/// Returns true if numeric type is integer
|
|
bool is_integer_type(NumericTypeID type) {
|
|
switch (type) {
|
|
case NumericTypeID::kS2: return true;
|
|
case NumericTypeID::kS4: return true;
|
|
case NumericTypeID::kS8: return true;
|
|
case NumericTypeID::kS16: return true;
|
|
case NumericTypeID::kS32: return true;
|
|
case NumericTypeID::kS64: return true;
|
|
case NumericTypeID::kU2: return true;
|
|
case NumericTypeID::kU4: return true;
|
|
case NumericTypeID::kU8: return true;
|
|
case NumericTypeID::kU16: return true;
|
|
case NumericTypeID::kU32: return true;
|
|
case NumericTypeID::kU64: return true;
|
|
default: break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Returns true if numeric type is signed
|
|
bool is_signed_type(NumericTypeID type) {
|
|
switch (type) {
|
|
case NumericTypeID::kF16: return true;
|
|
case NumericTypeID::kBF16: return true;
|
|
case NumericTypeID::kTF32: return true;
|
|
case NumericTypeID::kF32: return true;
|
|
case NumericTypeID::kF64: return true;
|
|
case NumericTypeID::kS2: return true;
|
|
case NumericTypeID::kS4: return true;
|
|
case NumericTypeID::kS8: return true;
|
|
case NumericTypeID::kS16: return true;
|
|
case NumericTypeID::kS32: return true;
|
|
case NumericTypeID::kS64: return true;
|
|
default: break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Returns true if numeric type is a signed integer
|
|
bool is_signed_integer(NumericTypeID type) {
|
|
return is_integer_type(type) && is_signed_type(type);
|
|
}
|
|
|
|
/// returns true if numeric type is an unsigned integer
|
|
bool is_unsigned_integer(NumericTypeID type) {
|
|
return is_integer_type(type) && !is_signed_type(type);
|
|
}
|
|
|
|
/// Returns true if numeric type is floating-point type
|
|
bool is_float_type(NumericTypeID type) {
|
|
switch (type) {
|
|
case NumericTypeID::kF16: return true;
|
|
case NumericTypeID::kBF16: return true;
|
|
case NumericTypeID::kTF32: return true;
|
|
case NumericTypeID::kF32: return true;
|
|
case NumericTypeID::kF64: return true;
|
|
case NumericTypeID::kCF16: return true;
|
|
case NumericTypeID::kCBF16: return true;
|
|
case NumericTypeID::kCTF32: return true;
|
|
case NumericTypeID::kCF32: return true;
|
|
case NumericTypeID::kCF64: return true;
|
|
default: break;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
LayoutTypeID layout;
|
|
char const *alias;
|
|
}
|
|
layout_aliases[] = {
|
|
{LayoutTypeID::kUnknown, "unknown"},
|
|
{LayoutTypeID::kRowMajor, "row"},
|
|
{LayoutTypeID::kRowMajor, "t"},
|
|
{LayoutTypeID::kColumnMajor, "column"},
|
|
{LayoutTypeID::kColumnMajor, "col"},
|
|
{LayoutTypeID::kColumnMajor, "n"},
|
|
|
|
{LayoutTypeID::kColumnMajorInterleavedK2, "nk2"},
|
|
{LayoutTypeID::kRowMajorInterleavedK2, "tk2"},
|
|
|
|
{LayoutTypeID::kColumnMajorInterleavedK4, "nk4"},
|
|
{LayoutTypeID::kRowMajorInterleavedK4, "tk4"},
|
|
|
|
{LayoutTypeID::kColumnMajorInterleavedK16, "nk16"},
|
|
{LayoutTypeID::kRowMajorInterleavedK16, "tk16"},
|
|
|
|
{LayoutTypeID::kColumnMajorInterleavedK32, "nk32"},
|
|
{LayoutTypeID::kRowMajorInterleavedK32, "tk32"},
|
|
|
|
{LayoutTypeID::kColumnMajorInterleavedK64, "nk64"},
|
|
{LayoutTypeID::kRowMajorInterleavedK64, "tk64"},
|
|
|
|
{LayoutTypeID::kTensorNCHW, "nchw"},
|
|
{LayoutTypeID::kTensorNCDHW, "ncdhw"},
|
|
{LayoutTypeID::kTensorNHWC, "nhwc"},
|
|
{LayoutTypeID::kTensorNDHWC, "ndhwc"},
|
|
{LayoutTypeID::kTensorNC32HW32, "nc32hw32"},
|
|
{LayoutTypeID::kTensorNC64HW64, "nc64hw64"},
|
|
{LayoutTypeID::kTensorC32RSK32, "c32rsk32"},
|
|
{LayoutTypeID::kTensorC64RSK64, "c64rsk64"},
|
|
|
|
{LayoutTypeID::kUnknown, "*"},
|
|
{LayoutTypeID::kInvalid, nullptr}
|
|
};
|
|
|
|
/// Converts a LayoutTypeID enumerant to a string
|
|
char const *to_string(LayoutTypeID layout, bool pretty) {
|
|
for (auto const & alias : layout_aliases) {
|
|
if (alias.layout == layout) {
|
|
return alias.alias;
|
|
}
|
|
}
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Parses a LayoutTypeID enumerant from a string
|
|
template <>
|
|
LayoutTypeID from_string<LayoutTypeID>(std::string const &str) {
|
|
for (auto const & alias : layout_aliases) {
|
|
if (str.compare(alias.alias) == 0) {
|
|
return alias.layout;
|
|
}
|
|
}
|
|
return LayoutTypeID::kInvalid;
|
|
}
|
|
|
|
/// Gets stride rank for the layout_id (static function)
|
|
int get_layout_stride_rank(LayoutTypeID layout_id) {
|
|
switch (layout_id) {
|
|
case LayoutTypeID::kColumnMajor:
|
|
return cutlass::layout::ColumnMajor::kStrideRank;
|
|
case LayoutTypeID::kRowMajor:
|
|
return cutlass::layout::RowMajor::kStrideRank;
|
|
case LayoutTypeID::kColumnMajorInterleavedK2:
|
|
return cutlass::layout::ColumnMajorInterleaved<2>::kStrideRank;
|
|
case LayoutTypeID::kRowMajorInterleavedK2:
|
|
return cutlass::layout::RowMajorInterleaved<2>::kStrideRank;
|
|
case LayoutTypeID::kColumnMajorInterleavedK4:
|
|
return cutlass::layout::ColumnMajorInterleaved<4>::kStrideRank;
|
|
case LayoutTypeID::kRowMajorInterleavedK4:
|
|
return cutlass::layout::RowMajorInterleaved<4>::kStrideRank;
|
|
case LayoutTypeID::kColumnMajorInterleavedK16:
|
|
return cutlass::layout::ColumnMajorInterleaved<16>::kStrideRank;
|
|
case LayoutTypeID::kRowMajorInterleavedK16:
|
|
return cutlass::layout::RowMajorInterleaved<16>::kStrideRank;
|
|
case LayoutTypeID::kColumnMajorInterleavedK32:
|
|
return cutlass::layout::ColumnMajorInterleaved<32>::kStrideRank;
|
|
case LayoutTypeID::kRowMajorInterleavedK32:
|
|
return cutlass::layout::RowMajorInterleaved<32>::kStrideRank;
|
|
case LayoutTypeID::kColumnMajorInterleavedK64:
|
|
return cutlass::layout::ColumnMajorInterleaved<64>::kStrideRank;
|
|
case LayoutTypeID::kRowMajorInterleavedK64:
|
|
return cutlass::layout::RowMajorInterleaved<64>::kStrideRank;
|
|
case LayoutTypeID::kTensorNCHW:
|
|
return cutlass::layout::TensorNCHW::kStrideRank;
|
|
case LayoutTypeID::kTensorNHWC:
|
|
return cutlass::layout::TensorNHWC::kStrideRank;
|
|
case LayoutTypeID::kTensorNDHWC:
|
|
return cutlass::layout::TensorNDHWC::kStrideRank;
|
|
case LayoutTypeID::kTensorNC32HW32:
|
|
return cutlass::layout::TensorNCxHWx<32>::kStrideRank;
|
|
case LayoutTypeID::kTensorNC64HW64:
|
|
return cutlass::layout::TensorNCxHWx<64>::kStrideRank;
|
|
case LayoutTypeID::kTensorC32RSK32:
|
|
return cutlass::layout::TensorCxRSKx<32>::kStrideRank;
|
|
case LayoutTypeID::kTensorC64RSK64:
|
|
return cutlass::layout::TensorCxRSKx<64>::kStrideRank;
|
|
default:
|
|
throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank");
|
|
}
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
OpcodeClassID enumerant;
|
|
}
|
|
OpcodeClassID_enumerants[] = {
|
|
{"simt", "<simt>", OpcodeClassID::kSimt},
|
|
{"tensorop", "<tensorop>", OpcodeClassID::kTensorOp},
|
|
{"wmmatensorop", "<wmmatensorop>", OpcodeClassID::kWmmaTensorOp},
|
|
{"wmma", "<wmma>", OpcodeClassID::kWmmaTensorOp},
|
|
};
|
|
|
|
/// Converts a OpcodeClassID enumerant to a string
|
|
char const *to_string(OpcodeClassID type, bool pretty) {
|
|
|
|
for (auto const & possible : OpcodeClassID_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Converts a OpcodeClassID enumerant from a string
|
|
template <>
|
|
OpcodeClassID from_string<OpcodeClassID>(std::string const &str) {
|
|
|
|
for (auto const & possible : OpcodeClassID_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return OpcodeClassID::kInvalid;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
ComplexTransform enumerant;
|
|
}
|
|
ComplexTransform_enumerants[] = {
|
|
{"n", "none", ComplexTransform::kNone},
|
|
{"c", "conj", ComplexTransform::kConjugate}
|
|
};
|
|
|
|
/// Converts a ComplexTransform enumerant to a string
|
|
char const *to_string(ComplexTransform type, bool pretty) {
|
|
|
|
for (auto const & possible : ComplexTransform_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Converts a ComplexTransform enumerant from a string
|
|
template <>
|
|
ComplexTransform from_string<ComplexTransform>(std::string const &str) {
|
|
|
|
for (auto const & possible : ComplexTransform_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return ComplexTransform::kInvalid;
|
|
}
|
|
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
SplitKMode enumerant;
|
|
}
|
|
SplitKMode_enumerants[] = {
|
|
{"serial", "<serial>", SplitKMode::kSerial},
|
|
{"parallel", "<parallel>", SplitKMode::kParallel},
|
|
};
|
|
|
|
/// Converts a SplitKMode enumerant to a string
|
|
char const *to_string(SplitKMode type, bool pretty) {
|
|
|
|
for (auto const & possible : SplitKMode_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Converts a SplitKMode enumerant from a string
|
|
template <>
|
|
SplitKMode from_string<SplitKMode>(std::string const &str) {
|
|
|
|
for (auto const & possible : SplitKMode_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return SplitKMode::kInvalid;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
ConvModeID enumerant;
|
|
}
|
|
ConvModeID_enumerants[] = {
|
|
{"cross", "<cross>", ConvModeID::kCrossCorrelation},
|
|
{"conv", "<conv>", ConvModeID::kConvolution},
|
|
};
|
|
|
|
/// Converts a ConvModeID enumerant to a string
|
|
char const *to_string(ConvModeID type, bool pretty) {
|
|
|
|
for (auto const & possible : ConvModeID_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Converts a ConvModeID enumerant from a string
|
|
template <>
|
|
ConvModeID from_string<ConvModeID>(std::string const &str) {
|
|
|
|
for (auto const & possible : ConvModeID_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return ConvModeID::kInvalid;
|
|
}
|
|
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
IteratorAlgorithmID enumerant;
|
|
}
|
|
IteratorAlgorithmID_enumerants[] = {
|
|
{"none", "<none>", IteratorAlgorithmID::kNone},
|
|
{"analytic", "<analytic>", IteratorAlgorithmID::kAnalytic},
|
|
{"optimized", "<optimized>", IteratorAlgorithmID::kOptimized},
|
|
{"fixed_channels", "<fixed_channels>", IteratorAlgorithmID::kFixedChannels},
|
|
{"few_channels", "<few_channels>", IteratorAlgorithmID::kFewChannels},
|
|
};
|
|
|
|
/// Converts a ConvModeID enumerant to a string
|
|
char const *to_string(IteratorAlgorithmID type, bool pretty) {
|
|
|
|
for (auto const & possible : IteratorAlgorithmID_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
/// Converts a ConvModeID enumerant from a string
|
|
template <>
|
|
IteratorAlgorithmID from_string<IteratorAlgorithmID>(std::string const &str) {
|
|
|
|
for (auto const & possible : IteratorAlgorithmID_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return IteratorAlgorithmID::kInvalid;
|
|
}
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
static struct {
|
|
char const *text;
|
|
char const *pretty;
|
|
ConvKind enumerant;
|
|
}
|
|
ConvKind_enumerants[] = {
|
|
{"unknown", "<unkown>", ConvKind::kUnknown},
|
|
{"fprop", "<fprop>", ConvKind::kFprop},
|
|
{"dgrad", "<dgrad>", ConvKind::kDgrad},
|
|
{"wgrad", "<wgrad>", ConvKind::kWgrad},
|
|
};
|
|
|
|
/// Converts a ConvKind enumerant to a string
|
|
char const *to_string(ConvKind type, bool pretty) {
|
|
|
|
for (auto const & possible : ConvKind_enumerants) {
|
|
if (type == possible.enumerant) {
|
|
if (pretty) {
|
|
return possible.pretty;
|
|
}
|
|
else {
|
|
return possible.text;
|
|
}
|
|
}
|
|
}
|
|
|
|
return pretty ? "Invalid" : "invalid";
|
|
}
|
|
|
|
|
|
/// Converts a ConvKind enumerant from a string
|
|
template <>
|
|
ConvKind from_string<ConvKind>(std::string const &str) {
|
|
|
|
for (auto const & possible : ConvKind_enumerants) {
|
|
if ((str.compare(possible.text) == 0) ||
|
|
(str.compare(possible.pretty) == 0)) {
|
|
return possible.enumerant;
|
|
}
|
|
}
|
|
|
|
return ConvKind::kInvalid;
|
|
}
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
|
|
bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str) {
|
|
int size_bytes = sizeof_bits(type) / 8;
|
|
if (!size_bytes) {
|
|
return false;
|
|
}
|
|
|
|
bytes.resize(size_bytes, 0);
|
|
|
|
std::stringstream ss;
|
|
ss << str;
|
|
|
|
switch (type) {
|
|
case NumericTypeID::kU8:
|
|
{
|
|
ss >> *reinterpret_cast<uint8_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kU16:
|
|
{
|
|
ss >> *reinterpret_cast<uint16_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kU32:
|
|
{
|
|
ss >> *reinterpret_cast<uint32_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kU64:
|
|
{
|
|
ss >> *reinterpret_cast<uint64_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kS8:
|
|
{
|
|
ss >> *reinterpret_cast<int8_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kS16:
|
|
{
|
|
ss >> *reinterpret_cast<int16_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kS32:
|
|
{
|
|
ss >> *reinterpret_cast<int32_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kS64:
|
|
{
|
|
ss >> *reinterpret_cast<int64_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kF16:
|
|
{
|
|
float tmp;
|
|
ss >> tmp;
|
|
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(tmp);
|
|
}
|
|
break;
|
|
case NumericTypeID::kBF16:
|
|
{
|
|
float tmp;
|
|
ss >> tmp;
|
|
*reinterpret_cast<bfloat16_t *>(bytes.data()) = static_cast<bfloat16_t>(tmp);
|
|
}
|
|
break;
|
|
case NumericTypeID::kTF32:
|
|
{
|
|
float tmp;
|
|
ss >> tmp;
|
|
*reinterpret_cast<tfloat32_t *>(bytes.data()) = static_cast<tfloat32_t>(tmp);
|
|
}
|
|
break;
|
|
case NumericTypeID::kF32:
|
|
{
|
|
ss >> *reinterpret_cast<float *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kF64:
|
|
{
|
|
ss >> *reinterpret_cast<double *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF16:
|
|
{
|
|
std::complex<float> tmp;
|
|
ss >> tmp;
|
|
cutlass::complex<cutlass::half_t> *x = reinterpret_cast<cutlass::complex<half_t> *>(bytes.data());
|
|
x->real() = static_cast<half_t>(std::real(tmp));
|
|
x->imag() = static_cast<half_t>(std::imag(tmp));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCBF16:
|
|
{
|
|
std::complex<float> tmp;
|
|
ss >> tmp;
|
|
cutlass::complex<cutlass::bfloat16_t> *x = reinterpret_cast<cutlass::complex<bfloat16_t> *>(bytes.data());
|
|
x->real() = static_cast<bfloat16_t>(std::real(tmp));
|
|
x->imag() = static_cast<bfloat16_t>(std::imag(tmp));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF32:
|
|
{
|
|
ss >> *reinterpret_cast<std::complex<float>*>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kCTF32:
|
|
{
|
|
std::complex<float> tmp;
|
|
ss >> tmp;
|
|
cutlass::complex<cutlass::tfloat32_t> *x = reinterpret_cast<cutlass::complex<tfloat32_t> *>(bytes.data());
|
|
x->real() = static_cast<tfloat32_t>(std::real(tmp));
|
|
x->imag() = static_cast<tfloat32_t>(std::imag(tmp));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF64:
|
|
{
|
|
ss >> *reinterpret_cast<std::complex<double>*>(bytes.data());
|
|
}
|
|
break;
|
|
default:
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
std::string lexical_cast(int64_t int_value) {
|
|
std::stringstream ss;
|
|
ss << int_value;
|
|
return ss.str();
|
|
}
|
|
|
|
/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid.
|
|
std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
|
|
|
|
int size_bytes = sizeof_bits(type) / 8;
|
|
|
|
if (!size_bytes || size_bytes != bytes.size()) {
|
|
return "<invalid>";
|
|
}
|
|
|
|
bytes.resize(size_bytes, 0);
|
|
|
|
std::stringstream ss;
|
|
|
|
switch (type) {
|
|
case NumericTypeID::kU8:
|
|
{
|
|
ss << *reinterpret_cast<uint8_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kU16:
|
|
{
|
|
ss << *reinterpret_cast<uint16_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kU32:
|
|
{
|
|
ss << *reinterpret_cast<uint32_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kU64:
|
|
{
|
|
ss << *reinterpret_cast<uint64_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kS8:
|
|
{
|
|
ss << *reinterpret_cast<int8_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kS16:
|
|
{
|
|
ss << *reinterpret_cast<int16_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kS32:
|
|
{
|
|
ss << *reinterpret_cast<int32_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kS64:
|
|
{
|
|
ss << *reinterpret_cast<int64_t *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kF16:
|
|
{
|
|
float tmp = *reinterpret_cast<half_t *>(bytes.data());
|
|
ss << tmp;
|
|
}
|
|
break;
|
|
case NumericTypeID::kBF16:
|
|
{
|
|
float tmp = *reinterpret_cast<bfloat16_t *>(bytes.data());
|
|
ss << tmp;
|
|
}
|
|
break;
|
|
case NumericTypeID::kTF32:
|
|
{
|
|
float tmp = *reinterpret_cast<tfloat32_t *>(bytes.data());
|
|
ss << tmp;
|
|
}
|
|
break;
|
|
case NumericTypeID::kF32:
|
|
{
|
|
ss << *reinterpret_cast<float *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kF64:
|
|
{
|
|
ss << *reinterpret_cast<double *>(bytes.data());
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF16:
|
|
{
|
|
cutlass::complex<half_t> const *x =
|
|
reinterpret_cast<cutlass::complex<half_t> const *>(bytes.data());
|
|
|
|
ss << float(x->real());
|
|
|
|
if (x->imag() != cutlass::half_t()) {
|
|
ss << "+i" << float(x->imag());
|
|
}
|
|
}
|
|
break;
|
|
case NumericTypeID::kCBF16:
|
|
{
|
|
cutlass::complex<bfloat16_t> const *x =
|
|
reinterpret_cast<cutlass::complex<bfloat16_t> const *>(bytes.data());
|
|
|
|
ss << float(x->real());
|
|
|
|
if (x->imag() != cutlass::bfloat16_t()) {
|
|
ss << "+i" << float(x->imag());
|
|
}
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF32:
|
|
{
|
|
cutlass::complex<float> const * x = reinterpret_cast<cutlass::complex<float> const *>(bytes.data());
|
|
|
|
ss << x->real();
|
|
|
|
if (x->imag() != float()) {
|
|
ss << "+i" << x->imag();
|
|
}
|
|
}
|
|
break;
|
|
case NumericTypeID::kCTF32:
|
|
{
|
|
cutlass::complex<tfloat32_t> const * x = reinterpret_cast<cutlass::complex<tfloat32_t> const *>(bytes.data());
|
|
|
|
ss << float(x->real());
|
|
|
|
if (x->imag() != tfloat32_t()) {
|
|
ss << "+i" << float(x->imag());
|
|
}
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF64:
|
|
{
|
|
cutlass::complex<double> const * x = reinterpret_cast<cutlass::complex<double> const *>(bytes.data());
|
|
|
|
ss << x->real();
|
|
|
|
if (x->imag() != double()) {
|
|
ss << "+i" << x->imag();
|
|
}
|
|
}
|
|
break;
|
|
default:
|
|
return "<unknown>";
|
|
}
|
|
|
|
return ss.str();
|
|
}
|
|
|
|
/// Casts from a signed int64 to the destination type. Returns true if successful.
|
|
bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src) {
|
|
int size_bytes = sizeof_bits(type) / 8;
|
|
if (!size_bytes) {
|
|
return false;
|
|
}
|
|
|
|
bytes.resize(size_bytes, 0);
|
|
|
|
switch (type) {
|
|
case NumericTypeID::kU8:
|
|
{
|
|
*reinterpret_cast<uint8_t *>(bytes.data()) = static_cast<uint8_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU16:
|
|
{
|
|
*reinterpret_cast<uint16_t *>(bytes.data()) = static_cast<uint16_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU32:
|
|
{
|
|
*reinterpret_cast<uint32_t *>(bytes.data()) = static_cast<uint32_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU64:
|
|
{
|
|
*reinterpret_cast<uint64_t *>(bytes.data()) = static_cast<uint64_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS8:
|
|
{
|
|
*reinterpret_cast<int8_t *>(bytes.data()) = static_cast<int8_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS16:
|
|
{
|
|
*reinterpret_cast<int16_t *>(bytes.data()) = static_cast<int16_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS32:
|
|
{
|
|
*reinterpret_cast<int32_t *>(bytes.data()) = static_cast<int32_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS64:
|
|
{
|
|
*reinterpret_cast<int64_t *>(bytes.data()) = static_cast<int64_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kF16:
|
|
{
|
|
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kBF16:
|
|
{
|
|
*reinterpret_cast<bfloat16_t *>(bytes.data()) = static_cast<bfloat16_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kTF32:
|
|
{
|
|
*reinterpret_cast<tfloat32_t *>(bytes.data()) = static_cast<tfloat32_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kF32:
|
|
{
|
|
*reinterpret_cast<float *>(bytes.data()) = static_cast<float>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kF64:
|
|
{
|
|
*reinterpret_cast<double *>(bytes.data()) = double(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF16:
|
|
{
|
|
cutlass::complex<cutlass::half_t> *x = reinterpret_cast<cutlass::complex<half_t> *>(bytes.data());
|
|
x->real() = static_cast<half_t>(float(src));
|
|
x->imag() = static_cast<half_t>(float(0));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF32:
|
|
{
|
|
*reinterpret_cast<cutlass::complex<float>*>(bytes.data()) = cutlass::complex<float>(float(src), float(0));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF64:
|
|
{
|
|
*reinterpret_cast<cutlass::complex<double>*>(bytes.data()) = cutlass::complex<double>(double(src), double(0));
|
|
}
|
|
break;
|
|
default:
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
/// Casts from an unsigned int64 to the destination type. Returns true if successful.
|
|
bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src) {
|
|
int size_bytes = sizeof_bits(type) / 8;
|
|
if (!size_bytes) {
|
|
return false;
|
|
}
|
|
|
|
bytes.resize(size_bytes, 0);
|
|
|
|
switch (type) {
|
|
case NumericTypeID::kU8:
|
|
{
|
|
*reinterpret_cast<uint8_t *>(bytes.data()) = static_cast<uint8_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU16:
|
|
{
|
|
*reinterpret_cast<uint16_t *>(bytes.data()) = static_cast<uint16_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU32:
|
|
{
|
|
*reinterpret_cast<uint32_t *>(bytes.data()) = static_cast<uint32_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU64:
|
|
{
|
|
*reinterpret_cast<uint64_t *>(bytes.data()) = static_cast<uint64_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS8:
|
|
{
|
|
*reinterpret_cast<int8_t *>(bytes.data()) = static_cast<int8_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS16:
|
|
{
|
|
*reinterpret_cast<int16_t *>(bytes.data()) = static_cast<int16_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS32:
|
|
{
|
|
*reinterpret_cast<int32_t *>(bytes.data()) = static_cast<int32_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS64:
|
|
{
|
|
*reinterpret_cast<int64_t *>(bytes.data()) = static_cast<int64_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kF16:
|
|
{
|
|
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kBF16:
|
|
{
|
|
*reinterpret_cast<bfloat16_t *>(bytes.data()) = static_cast<bfloat16_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kTF32:
|
|
{
|
|
*reinterpret_cast<tfloat32_t *>(bytes.data()) = static_cast<tfloat32_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kF32:
|
|
{
|
|
*reinterpret_cast<float *>(bytes.data()) = static_cast<float>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kF64:
|
|
{
|
|
*reinterpret_cast<double *>(bytes.data()) = double(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF16:
|
|
{
|
|
cutlass::complex<cutlass::half_t> *x = reinterpret_cast<cutlass::complex<half_t> *>(bytes.data());
|
|
x->real() = static_cast<half_t>(float(src));
|
|
x->imag() = static_cast<half_t>(float(0));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF32:
|
|
{
|
|
*reinterpret_cast<std::complex<float>*>(bytes.data()) = std::complex<float>(float(src), float(0));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF64:
|
|
{
|
|
*reinterpret_cast<std::complex<double>*>(bytes.data()) = std::complex<double>(double(src), double(0));
|
|
}
|
|
break;
|
|
default:
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
|
|
bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src) {
|
|
|
|
int size_bytes = sizeof_bits(type) / 8;
|
|
if (!size_bytes) {
|
|
return false;
|
|
}
|
|
|
|
bytes.resize(size_bytes, 0);
|
|
|
|
switch (type) {
|
|
case NumericTypeID::kU8:
|
|
{
|
|
*reinterpret_cast<uint8_t *>(bytes.data()) = static_cast<uint8_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU16:
|
|
{
|
|
*reinterpret_cast<uint16_t *>(bytes.data()) = static_cast<uint16_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU32:
|
|
{
|
|
*reinterpret_cast<uint32_t *>(bytes.data()) = static_cast<uint32_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kU64:
|
|
{
|
|
*reinterpret_cast<uint64_t *>(bytes.data()) = static_cast<uint64_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS8:
|
|
{
|
|
*reinterpret_cast<int8_t *>(bytes.data()) = static_cast<int8_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS16:
|
|
{
|
|
*reinterpret_cast<int16_t *>(bytes.data()) = static_cast<int16_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS32:
|
|
{
|
|
*reinterpret_cast<int32_t *>(bytes.data()) = static_cast<int32_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kS64:
|
|
{
|
|
*reinterpret_cast<int64_t *>(bytes.data()) = static_cast<int64_t>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kF16:
|
|
{
|
|
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kBF16:
|
|
{
|
|
*reinterpret_cast<bfloat16_t *>(bytes.data()) = static_cast<bfloat16_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kTF32:
|
|
{
|
|
*reinterpret_cast<tfloat32_t *>(bytes.data()) = static_cast<tfloat32_t>(float(src));
|
|
}
|
|
break;
|
|
case NumericTypeID::kF32:
|
|
{
|
|
*reinterpret_cast<float *>(bytes.data()) = static_cast<float>(src);
|
|
}
|
|
break;
|
|
case NumericTypeID::kF64:
|
|
{
|
|
*reinterpret_cast<double *>(bytes.data()) = src;
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF16:
|
|
{
|
|
cutlass::complex<cutlass::half_t> *x = reinterpret_cast<cutlass::complex<half_t> *>(bytes.data());
|
|
x->real() = static_cast<half_t>(float(src));
|
|
x->imag() = static_cast<half_t>(float(0));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCBF16:
|
|
{
|
|
cutlass::complex<cutlass::bfloat16_t> *x = reinterpret_cast<cutlass::complex<bfloat16_t> *>(bytes.data());
|
|
x->real() = static_cast<bfloat16_t>(bfloat16_t(src));
|
|
x->imag() = static_cast<bfloat16_t>(bfloat16_t(0));
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF32:
|
|
{
|
|
*reinterpret_cast<cutlass::complex<float>*>(bytes.data()) = cutlass::complex<float>(float(src), float());
|
|
}
|
|
break;
|
|
case NumericTypeID::kCTF32:
|
|
{
|
|
*reinterpret_cast<cutlass::complex<tfloat32_t>*>(bytes.data()) = cutlass::complex<tfloat32_t>(tfloat32_t(src), tfloat32_t());
|
|
}
|
|
break;
|
|
case NumericTypeID::kCF64:
|
|
{
|
|
*reinterpret_cast<cutlass::complex<double>*>(bytes.data()) = cutlass::complex<double>(src, double());
|
|
}
|
|
break;
|
|
default:
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace library
|
|
} // namespace cutlass
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|