CUTLASS 2.3 adds GEMMs targeting Sparse Tensor Cores on the NVIDIA Ampere Architecture, fast SGEMM, and small matrix classes, bug fixes, and performance enhancements.
1058 lines
32 KiB
C++
1058 lines
32 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
|
* provided that the following conditions are met:
|
|
* * Redistributions of source code must retain the above copyright notice, this list of
|
|
* conditions and the following disclaimer.
|
|
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
|
* conditions and the following disclaimer in the documentation and/or other materials
|
|
* provided with the distribution.
|
|
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
|
* to endorse or promote products derived from this software without specific prior written
|
|
* permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
|
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
|
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
|
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
|
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
|
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
/* \file
|
|
\brief
|
|
*/
|
|
|
|
#include <string>
|
|
#include <stdexcept>
|
|
#include <sstream>
|
|
|
|
#include "cutlass/library/util.h"
|
|
|
|
#include "problem_space.h"
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass {
|
|
namespace profiler {
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
static T lexical_cast(std::string const &str) {
|
|
std::stringstream ss;
|
|
T value;
|
|
|
|
ss << str;
|
|
ss >> value;
|
|
|
|
return value;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
std::ostream & KernelArgument::ValueIterator::print(std::ostream &out) const {
|
|
out << "[" << (void *)this << " " << argument->qualified_name() << "] ";
|
|
if (this->null_argument) {
|
|
out << "<null>";
|
|
}
|
|
else {
|
|
out << "<not null>";
|
|
}
|
|
return out;
|
|
}
|
|
|
|
KernelArgument::~KernelArgument() {
|
|
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
ScalarArgument::ScalarValue::ScalarValue(
|
|
std::string const &value_,
|
|
ScalarArgument const *argument_,
|
|
bool not_null_
|
|
):
|
|
KernelArgument::Value(argument_, not_null_),
|
|
value(value_) {
|
|
|
|
}
|
|
|
|
std::ostream &ScalarArgument::ScalarValue::print(std::ostream &out) const {
|
|
out << argument->qualified_name() << ": ";
|
|
if (not_null) {
|
|
out << value;
|
|
}
|
|
else {
|
|
out << "<null>";
|
|
}
|
|
return out;
|
|
}
|
|
|
|
ScalarArgument::ScalarValueIterator::ScalarValueIterator(
|
|
ScalarArgument const *argument_
|
|
):
|
|
KernelArgument::ValueIterator(argument_) {
|
|
|
|
if (argument_) {
|
|
value_it = argument_->values.begin();
|
|
}
|
|
}
|
|
|
|
void ScalarArgument::ScalarValueIterator::operator++() {
|
|
if (this->null_argument) {
|
|
this->null_argument = false;
|
|
}
|
|
else {
|
|
++value_it;
|
|
}
|
|
}
|
|
|
|
bool ScalarArgument::ScalarValueIterator::operator==(ValueIterator const &it) const {
|
|
if (it.type() != ArgumentTypeID::kScalar) {
|
|
throw std::runtime_error("Cannot compare ScalarValueIterator with iterator of different type");
|
|
}
|
|
auto const & scalar_it = static_cast<ScalarValueIterator const &>(it);
|
|
return value_it == scalar_it.value_it;
|
|
}
|
|
|
|
/// Gets the value pointed to
|
|
std::unique_ptr<KernelArgument::Value> ScalarArgument::ScalarValueIterator::at() const {
|
|
if (this->null_argument) {
|
|
return std::unique_ptr<KernelArgument::Value>(
|
|
new ScalarArgument::ScalarValue(
|
|
std::string(),
|
|
static_cast<ScalarArgument const *>(argument),
|
|
false));
|
|
}
|
|
else {
|
|
return std::unique_ptr<KernelArgument::Value>(
|
|
new ScalarArgument::ScalarValue(
|
|
*value_it,
|
|
static_cast<ScalarArgument const *>(argument)));
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::ValueIterator> ScalarArgument::begin() const {
|
|
return std::unique_ptr<KernelArgument::ValueIterator>(new ScalarValueIterator(this));
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::ValueIterator> ScalarArgument::end() const {
|
|
ScalarValueIterator *it = new ScalarValueIterator(this);
|
|
it->value_it = this->values.end();
|
|
it->null_argument = false;
|
|
return std::unique_ptr<ValueIterator>(it);
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
IntegerArgument::IntegerValue::IntegerValue(
|
|
int64_t value_,
|
|
IntegerArgument const *argument_,
|
|
bool not_null_
|
|
): KernelArgument::Value(argument_, not_null_), value(value_) {
|
|
|
|
}
|
|
|
|
|
|
/// Pretty printer for debugging
|
|
std::ostream &IntegerArgument::IntegerValue::print(std::ostream &out) const {
|
|
out << argument->qualified_name() << ": ";
|
|
if (not_null) {
|
|
out << value;
|
|
}
|
|
else {
|
|
out << "<null>";
|
|
}
|
|
return out;
|
|
}
|
|
|
|
IntegerArgument::IntegerValueIterator::IntegerValueIterator(IntegerArgument const *argument_):
|
|
KernelArgument::ValueIterator(argument_) {
|
|
|
|
if (argument_) {
|
|
range_it = argument_->ranges.begin();
|
|
if (range_it != argument_->ranges.end()) {
|
|
value_it = range_it->begin();
|
|
}
|
|
}
|
|
}
|
|
|
|
void IntegerArgument::IntegerValueIterator::operator++() {
|
|
|
|
if (this->null_argument) {
|
|
this->null_argument = false;
|
|
}
|
|
else {
|
|
++value_it;
|
|
if (value_it == range_it->end()) {
|
|
++range_it;
|
|
if (range_it != static_cast<IntegerArgument const *>(argument)->ranges.end()) {
|
|
value_it = range_it->begin();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
bool IntegerArgument::IntegerValueIterator::operator==(ValueIterator const &it) const {
|
|
if (it.type() != ArgumentTypeID::kInteger) {
|
|
throw std::runtime_error("Cannot compare IntegerValueIterator with iterator of different type");
|
|
}
|
|
|
|
auto const & integer_iterator = static_cast<IntegerValueIterator const &>(it);
|
|
|
|
if (this->null_argument) {
|
|
return it.null_argument;
|
|
}
|
|
else {
|
|
if (range_it != integer_iterator.range_it) {
|
|
return false;
|
|
}
|
|
if (range_it == static_cast<IntegerArgument const *>(argument)->ranges.end() &&
|
|
range_it == integer_iterator.range_it) {
|
|
return true;
|
|
}
|
|
return value_it == integer_iterator.value_it;
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::Value> IntegerArgument::IntegerValueIterator::at() const {
|
|
if (this->null_argument) {
|
|
return std::unique_ptr<KernelArgument::Value>(
|
|
new IntegerArgument::IntegerValue(
|
|
0, static_cast<IntegerArgument const *>(argument), false));
|
|
}
|
|
else {
|
|
return std::unique_ptr<KernelArgument::Value>(
|
|
new IntegerArgument::IntegerValue(
|
|
*value_it, static_cast<IntegerArgument const *>(argument)));
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::ValueIterator> IntegerArgument::begin() const {
|
|
return std::unique_ptr<KernelArgument::ValueIterator>(new IntegerValueIterator(this));
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::ValueIterator> IntegerArgument::end() const {
|
|
IntegerValueIterator *it = new IntegerValueIterator(this);
|
|
it->range_it = this->ranges.end();
|
|
it->null_argument = false;
|
|
return std::unique_ptr<ValueIterator>(it);
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
TensorArgument::TensorValue::TensorValue(
|
|
TensorDescription const &desc_,
|
|
TensorArgument const *argument_,
|
|
bool not_null_
|
|
):
|
|
KernelArgument::Value(argument_, not_null_),
|
|
desc(desc_) {
|
|
|
|
}
|
|
|
|
/// Pretty printer for debugging
|
|
std::ostream &TensorArgument::TensorValue::print(std::ostream &out) const {
|
|
out << argument->qualified_name() << ": " << to_string(desc.element) << ": " << to_string(desc.layout);
|
|
return out;
|
|
}
|
|
|
|
TensorArgument::TensorValueIterator::TensorValueIterator(
|
|
TensorArgument const *argument_
|
|
):
|
|
KernelArgument::ValueIterator(argument_) {
|
|
|
|
if (argument_) {
|
|
value_it = argument_->values.begin();
|
|
}
|
|
}
|
|
|
|
void TensorArgument::TensorValueIterator::operator++() {
|
|
if (this->null_argument) {
|
|
this->null_argument = false;
|
|
}
|
|
else {
|
|
++value_it;
|
|
}
|
|
}
|
|
|
|
bool TensorArgument::TensorValueIterator::operator==(ValueIterator const &it) const {
|
|
if (it.type() != ArgumentTypeID::kTensor) {
|
|
throw std::runtime_error("Cannot compare TensorValueIterator with iterator of different type");
|
|
}
|
|
auto const & tensor_it = static_cast<TensorValueIterator const &>(it);
|
|
return value_it == tensor_it.value_it;
|
|
}
|
|
|
|
/// Gets the value pointed to
|
|
std::unique_ptr<KernelArgument::Value> TensorArgument::TensorValueIterator::at() const {
|
|
|
|
if (this->null_argument) {
|
|
return std::unique_ptr<KernelArgument::Value>(
|
|
new TensorArgument::TensorValue(
|
|
TensorDescription(), static_cast<TensorArgument const *>(argument), false));
|
|
}
|
|
else {
|
|
return std::unique_ptr<KernelArgument::Value>(
|
|
new TensorArgument::TensorValue(
|
|
*value_it, static_cast<TensorArgument const *>(argument)));
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::ValueIterator> TensorArgument::begin() const {
|
|
return std::unique_ptr<KernelArgument::ValueIterator>(new TensorValueIterator(this));
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::ValueIterator> TensorArgument::end() const {
|
|
TensorValueIterator *it = new TensorValueIterator(this);
|
|
it->value_it = this->values.end();
|
|
it->null_argument = false;
|
|
return std::unique_ptr<ValueIterator>(it);
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
EnumeratedTypeArgument::EnumeratedTypeValue::EnumeratedTypeValue(
|
|
std::string const & element_,
|
|
EnumeratedTypeArgument const *argument_,
|
|
bool not_null_
|
|
):
|
|
KernelArgument::Value(argument_, not_null_),
|
|
element(element_) {
|
|
|
|
}
|
|
|
|
/// Pretty printer for debugging
|
|
std::ostream &EnumeratedTypeArgument::EnumeratedTypeValue::print(std::ostream &out) const {
|
|
out << argument->qualified_name() << ": " << element;
|
|
return out;
|
|
}
|
|
|
|
EnumeratedTypeArgument::EnumeratedTypeValueIterator::EnumeratedTypeValueIterator(
|
|
EnumeratedTypeArgument const *argument_
|
|
):
|
|
KernelArgument::ValueIterator(argument_) {
|
|
|
|
if (argument_) {
|
|
value_it = argument_->values.begin();
|
|
}
|
|
}
|
|
|
|
void EnumeratedTypeArgument::EnumeratedTypeValueIterator::operator++() {
|
|
if (this->null_argument) {
|
|
this->null_argument = false;
|
|
}
|
|
else {
|
|
++value_it;
|
|
}
|
|
}
|
|
|
|
bool EnumeratedTypeArgument::EnumeratedTypeValueIterator::operator==(ValueIterator const &it) const {
|
|
|
|
if (it.type() != ArgumentTypeID::kEnumerated) {
|
|
throw std::runtime_error("Cannot compare EnumeratedTypeValueIterator with iterator of different type");
|
|
}
|
|
|
|
auto const & enumerated_type_it = static_cast<EnumeratedTypeValueIterator const &>(it);
|
|
return value_it == enumerated_type_it.value_it;
|
|
}
|
|
|
|
/// Gets the value pointed to
|
|
std::unique_ptr<KernelArgument::Value> EnumeratedTypeArgument::EnumeratedTypeValueIterator::at() const {
|
|
|
|
if (this->null_argument) {
|
|
return std::unique_ptr<KernelArgument::Value>(
|
|
new EnumeratedTypeValue(
|
|
std::string(), static_cast<EnumeratedTypeArgument const *>(argument), false));
|
|
}
|
|
else {
|
|
return std::unique_ptr<KernelArgument::Value>(
|
|
new EnumeratedTypeValue(
|
|
*value_it, static_cast<EnumeratedTypeArgument const *>(argument)));
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::ValueIterator> EnumeratedTypeArgument::begin() const {
|
|
return std::unique_ptr<KernelArgument::ValueIterator>(new EnumeratedTypeValueIterator(this));
|
|
}
|
|
|
|
std::unique_ptr<KernelArgument::ValueIterator> EnumeratedTypeArgument::end() const {
|
|
EnumeratedTypeValueIterator *it = new EnumeratedTypeValueIterator(this);
|
|
it->value_it = this->values.end();
|
|
it->null_argument = false;
|
|
return std::unique_ptr<ValueIterator>(it);
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
ProblemSpace::Iterator::Iterator() {
|
|
|
|
}
|
|
|
|
ProblemSpace::Iterator::Iterator(ProblemSpace const &problem_space) {
|
|
for (auto const & arg_ptr : problem_space.arguments) {
|
|
construct_(arg_ptr.get());
|
|
}
|
|
}
|
|
|
|
ProblemSpace::Iterator::Iterator(Iterator && it) {
|
|
iterators = std::move(it.iterators);
|
|
}
|
|
|
|
/// Helper for recursively constructing iterators
|
|
void ProblemSpace::Iterator::construct_(KernelArgument const *argument) {
|
|
iterators.emplace_back(argument->begin());
|
|
}
|
|
|
|
/// Given a set of ranges, iterate over the points within their Cartesian product. No big deal.
|
|
void ProblemSpace::Iterator::operator++() {
|
|
|
|
// Define a pair of iterator into the vector of iterators.
|
|
IteratorVector::iterator iterator_it = iterators.begin();
|
|
IteratorVector::iterator next_iterator = iterator_it;
|
|
|
|
// Advance the first argument.
|
|
++(**iterator_it);
|
|
|
|
// Maintain a pair of iterators over consecutive arguments.
|
|
++next_iterator;
|
|
|
|
// Carry logic
|
|
while (next_iterator != iterators.end() &&
|
|
**iterator_it == *((*iterator_it)->argument->end())) { // Did an iterator reach the end of its range?
|
|
|
|
(*iterator_it) = (*iterator_it)->argument->begin(); // Reset that iterator,
|
|
|
|
++(**next_iterator); // and increment the next argument's iterator.
|
|
|
|
iterator_it = next_iterator; // Advance to the next argument
|
|
++next_iterator;
|
|
}
|
|
}
|
|
|
|
/// Moves iterator to end
|
|
void ProblemSpace::Iterator::move_to_end() {
|
|
if (!iterators.empty()) {
|
|
std::unique_ptr<KernelArgument::ValueIterator> new_iter = iterators.back()->argument->end();
|
|
std::swap(iterators.back(), new_iter);
|
|
}
|
|
}
|
|
|
|
ProblemSpace::Problem ProblemSpace::Iterator::at() const {
|
|
Problem problem;
|
|
|
|
for (std::unique_ptr<KernelArgument::ValueIterator> const & it : iterators) {
|
|
problem.emplace_back(it->at());
|
|
}
|
|
|
|
return problem;
|
|
}
|
|
|
|
/// Equality operator
|
|
bool ProblemSpace::Iterator::operator==(Iterator const &it) const {
|
|
|
|
// This would be an opportunity for auto, but explicitly denoting references to
|
|
// owning smart pointers to dynamic polymorphic objects seems like a kindness to the reader.
|
|
IteratorVector::const_iterator first_it = iterators.begin();
|
|
IteratorVector::const_iterator second_it = it.iterators.begin();
|
|
|
|
int idx = 0;
|
|
for (; first_it != iterators.end(); ++first_it, ++second_it, ++idx) {
|
|
|
|
KernelArgument::ValueIterator const *my_it = first_it->get();
|
|
KernelArgument::ValueIterator const *their_it = second_it->get();
|
|
|
|
if (*my_it != *their_it) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
std::ostream &ProblemSpace::Iterator::print(std::ostream &out) const {
|
|
|
|
for (std::unique_ptr<KernelArgument::ValueIterator> const & iter_ptr : iterators) {
|
|
out << " [iter " << (iter_ptr->null_argument ? "null" : "<not null>")
|
|
<< ", type: " << to_string(iter_ptr->argument->description->type) << "]" << std::endl;
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
ProblemSpace::ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline) {
|
|
|
|
// Clone the arguments
|
|
for (ArgumentDescription const & arg_desc : schema) {
|
|
clone_(arguments, &arg_desc);
|
|
}
|
|
|
|
// Parse values from the command line
|
|
for (auto & arg : arguments) {
|
|
parse_(arg.get(), cmdline);
|
|
}
|
|
}
|
|
|
|
|
|
/// Returns the index of an argument by name
|
|
size_t ProblemSpace::argument_index(char const *name) const {
|
|
return argument_index_map.at(name);
|
|
}
|
|
|
|
/// Helper for recursively cloning
|
|
void ProblemSpace::clone_(
|
|
KernelArgumentVector &kernel_args,
|
|
ArgumentDescription const *arg_desc) {
|
|
|
|
KernelArgument *kernel_arg = nullptr;
|
|
|
|
switch (arg_desc->type) {
|
|
case ArgumentTypeID::kScalar:
|
|
kernel_arg = new ScalarArgument(arg_desc);
|
|
break;
|
|
case ArgumentTypeID::kInteger:
|
|
kernel_arg = new IntegerArgument(arg_desc);
|
|
break;
|
|
case ArgumentTypeID::kTensor:
|
|
kernel_arg = new TensorArgument(arg_desc);
|
|
break;
|
|
case ArgumentTypeID::kStructure:
|
|
{
|
|
throw std::runtime_error("ArgumentTypeID::kStructure not supported");
|
|
}
|
|
break;
|
|
case ArgumentTypeID::kEnumerated:
|
|
kernel_arg = new EnumeratedTypeArgument(arg_desc);
|
|
break;
|
|
|
|
default: break;
|
|
}
|
|
|
|
if (kernel_arg) {
|
|
size_t idx = kernel_args.size();
|
|
for (auto const &alias : arg_desc->aliases) {
|
|
argument_index_map.insert(std::make_pair(alias, idx));
|
|
}
|
|
kernel_args.emplace_back(kernel_arg);
|
|
}
|
|
}
|
|
|
|
/// Parses a command line
|
|
void ProblemSpace::parse_(KernelArgument *arg, CommandLine const &cmdline) {
|
|
|
|
switch (arg->description->type) {
|
|
case ArgumentTypeID::kScalar:
|
|
{
|
|
auto * scalar = static_cast<ScalarArgument *>(arg);
|
|
|
|
for (auto const &alias : arg->description->aliases) {
|
|
if (cmdline.check_cmd_line_flag(alias.c_str())) {
|
|
|
|
std::vector<std::vector<std::string>> tokens;
|
|
cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens);
|
|
|
|
for (auto const & vec : tokens) {
|
|
if (!vec.empty()) {
|
|
scalar->values.push_back(vec.front());
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
case ArgumentTypeID::kInteger:
|
|
{
|
|
auto *integer = static_cast<IntegerArgument *>(arg);
|
|
|
|
for (auto const &alias : arg->description->aliases) {
|
|
if (cmdline.check_cmd_line_flag(alias.c_str())) {
|
|
|
|
std::vector<std::vector<std::string> > tokens;
|
|
cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens);
|
|
|
|
for (auto &range_tokens : tokens) {
|
|
|
|
if (!range_tokens.empty()) {
|
|
|
|
Range range;
|
|
|
|
if (range_tokens.front() == "rand") {
|
|
range.mode = Range::Mode::kRandom;
|
|
}
|
|
else if (range_tokens.front() == "randlg2") {
|
|
range.mode = Range::Mode::kRandomLog2;
|
|
}
|
|
|
|
switch (range.mode) {
|
|
case Range::Mode::kSequence:
|
|
{
|
|
range.first = lexical_cast<int64_t>(range_tokens.front());
|
|
|
|
if (range_tokens.size() > 1) {
|
|
range.last = lexical_cast<int64_t>(range_tokens.at(1));
|
|
}
|
|
else {
|
|
range.last = range.first;
|
|
}
|
|
|
|
if (range_tokens.size() > 2) {
|
|
range.increment = lexical_cast<int64_t>(range_tokens.at(2));
|
|
}
|
|
else {
|
|
range.increment = 1;
|
|
}
|
|
}
|
|
break;
|
|
case Range::Mode::kRandom: // fall-through
|
|
case Range::Mode::kRandomLog2:
|
|
{
|
|
if (range_tokens.size() < 4) {
|
|
throw std::runtime_error(
|
|
"Range of mode 'rand' must have four tokens showing "
|
|
"the minimum, maximum, and number of iterations. For example, "
|
|
"rand:16:128:1000");
|
|
}
|
|
|
|
range.minimum = lexical_cast<int64_t>(range_tokens.at(1));
|
|
range.maximum = lexical_cast<int64_t>(range_tokens.at(2));
|
|
range.first = 1;
|
|
range.last = lexical_cast<int64_t>(range_tokens.at(3));
|
|
range.increment = 1;
|
|
|
|
if (range_tokens.size() > 4) {
|
|
range.divisible = lexical_cast<int64_t>(range_tokens.at(4));
|
|
}
|
|
}
|
|
break;
|
|
default:
|
|
throw std::runtime_error("Unsupported range mode.");
|
|
break;
|
|
}
|
|
|
|
integer->ranges.push_back(range);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
case ArgumentTypeID::kTensor:
|
|
{
|
|
auto *tensor = static_cast<TensorArgument *>(arg);
|
|
|
|
for (auto const &alias : arg->description->aliases) {
|
|
if (cmdline.check_cmd_line_flag(alias.c_str())) {
|
|
|
|
std::vector<std::vector<std::string>> tokens;
|
|
|
|
cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens);
|
|
|
|
for (auto const & tensor_tokens : tokens) {
|
|
if (!tensor_tokens.empty()) {
|
|
TensorArgument::TensorDescription tensor_desc;
|
|
|
|
tensor_desc.element = cutlass::library::from_string<library::NumericTypeID>(tensor_tokens.front());
|
|
|
|
// Layout
|
|
if (tensor_tokens.size() > 1) {
|
|
tensor_desc.layout = cutlass::library::from_string<library::LayoutTypeID>(tensor_tokens.at(1));
|
|
}
|
|
|
|
// Stride
|
|
for (size_t i = 2; i < tensor_tokens.size(); ++i) {
|
|
tensor_desc.stride.push_back(lexical_cast<int>(tensor_tokens.at(i)));
|
|
}
|
|
|
|
tensor->values.push_back(tensor_desc);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
case ArgumentTypeID::kStructure:
|
|
{
|
|
throw std::runtime_error("Structure arguments not supported");
|
|
}
|
|
break;
|
|
case ArgumentTypeID::kEnumerated:
|
|
{
|
|
auto *enumerated_type = static_cast<EnumeratedTypeArgument *>(arg);
|
|
|
|
for (auto const &alias : arg->description->aliases) {
|
|
if (cmdline.check_cmd_line_flag(alias.c_str())) {
|
|
|
|
std::vector<std::string> tokens;
|
|
cmdline.get_cmd_line_arguments(alias.c_str(), tokens);
|
|
|
|
for (auto const & token : tokens) {
|
|
enumerated_type->values.push_back(token);
|
|
}
|
|
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
ProblemSpace::Iterator ProblemSpace::begin() const {
|
|
return ProblemSpace::Iterator(*this);
|
|
}
|
|
|
|
ProblemSpace::Iterator ProblemSpace::end() const {
|
|
ProblemSpace::Iterator it(*this);
|
|
it.move_to_end();
|
|
return it;
|
|
}
|
|
|
|
/// Gets all argument names as an ordered vector
|
|
std::vector<std::string> ProblemSpace::argument_names() const {
|
|
|
|
Problem problem = this->begin().at();
|
|
|
|
std::vector<std::string> names;
|
|
names.reserve(problem.size());
|
|
|
|
for (auto const & arg : problem) {
|
|
names.push_back(arg->argument->description->aliases.front());
|
|
}
|
|
|
|
return names;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr) {
|
|
if (value_ptr->not_null) {
|
|
if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) {
|
|
int_value = static_cast<IntegerArgument::IntegerValue const *>(value_ptr)->value;
|
|
}
|
|
else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) {
|
|
std::stringstream ss;
|
|
ss << static_cast<ScalarArgument::ScalarValue const *>(value_ptr)->value;
|
|
ss >> int_value;
|
|
}
|
|
else {
|
|
throw std::runtime_error(
|
|
"arg_as_int64_t() - illegal cast. Problem space argument must be integer or scalar");
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr) {
|
|
int64_t value64;
|
|
bool obtained = arg_as_int(value64, value_ptr);
|
|
if (obtained) {
|
|
int_value = int(value64);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Lexically casts an argument to an int
|
|
bool arg_as_int(
|
|
int &int_value,
|
|
char const *name,
|
|
ProblemSpace const &problem_space,
|
|
ProblemSpace::Problem const &problem) {
|
|
|
|
size_t idx = problem_space.argument_index(name);
|
|
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
|
|
|
return arg_as_int(int_value, value_ptr);
|
|
}
|
|
|
|
/// Lexically casts an argument to an int64
|
|
bool arg_as_int(
|
|
int64_t &int_value,
|
|
char const *name,
|
|
ProblemSpace const &problem_space,
|
|
ProblemSpace::Problem const &problem) {
|
|
|
|
size_t idx = problem_space.argument_index(name);
|
|
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
|
|
|
return arg_as_int(int_value, value_ptr);
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_NumericTypeID(
|
|
library::NumericTypeID &numeric_type,
|
|
KernelArgument::Value const *value_ptr) {
|
|
|
|
if (value_ptr->not_null) {
|
|
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) {
|
|
|
|
numeric_type = library::from_string<library::NumericTypeID>(
|
|
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element);
|
|
|
|
if (numeric_type == library::NumericTypeID::kInvalid) {
|
|
throw std::runtime_error(
|
|
"arg_as_NumericTypeID() - illegal cast.");
|
|
}
|
|
}
|
|
else {
|
|
|
|
throw std::runtime_error(
|
|
"arg_as_NumericTypeID() - illegal cast.");
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_NumericTypeID(
|
|
library::NumericTypeID &numeric_type,
|
|
char const *name,
|
|
ProblemSpace const &problem_space,
|
|
ProblemSpace::Problem const &problem) {
|
|
|
|
size_t idx = problem_space.argument_index(name);
|
|
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
|
|
|
return arg_as_NumericTypeID(numeric_type, value_ptr);
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_LayoutTypeID(
|
|
library::LayoutTypeID &layout_type,
|
|
KernelArgument::Value const *value_ptr) {
|
|
|
|
if (value_ptr->not_null) {
|
|
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) {
|
|
|
|
layout_type = library::from_string<library::LayoutTypeID>(
|
|
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element);
|
|
|
|
if (layout_type == library::LayoutTypeID::kInvalid) {
|
|
throw std::runtime_error(
|
|
"arg_as_LayoutTypeID() - illegal cast.");
|
|
}
|
|
}
|
|
else {
|
|
|
|
throw std::runtime_error(
|
|
"arg_as_LayoutTypeID() - illegal cast.");
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_LayoutTypeID(
|
|
library::LayoutTypeID &layout_type,
|
|
char const *name,
|
|
ProblemSpace const &problem_space,
|
|
ProblemSpace::Problem const &problem) {
|
|
|
|
size_t idx = problem_space.argument_index(name);
|
|
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
|
|
|
return arg_as_LayoutTypeID(layout_type, value_ptr);
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_OpcodeClassID(
|
|
library::OpcodeClassID &opcode_class,
|
|
KernelArgument::Value const *value_ptr) {
|
|
|
|
if (value_ptr->not_null) {
|
|
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) {
|
|
|
|
opcode_class = library::from_string<library::OpcodeClassID>(
|
|
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element);
|
|
|
|
if (opcode_class == library::OpcodeClassID::kInvalid) {
|
|
throw std::runtime_error(
|
|
"arg_as_OpcodeClassID() - illegal cast.");
|
|
}
|
|
}
|
|
else {
|
|
|
|
throw std::runtime_error(
|
|
"arg_as_OpcodeClassID() - illegal cast.");
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_OpcodeClassID(
|
|
library::OpcodeClassID &opcode_class,
|
|
char const *name,
|
|
ProblemSpace const &problem_space,
|
|
ProblemSpace::Problem const &problem) {
|
|
|
|
size_t idx = problem_space.argument_index(name);
|
|
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
|
|
|
return arg_as_OpcodeClassID(opcode_class, value_ptr);
|
|
}
|
|
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_SplitKModeID(
|
|
library::SplitKMode &split_k_mode,
|
|
KernelArgument::Value const *value_ptr) {
|
|
|
|
if (value_ptr->not_null) {
|
|
if (value_ptr->argument->description->type == ArgumentTypeID::kEnumerated) {
|
|
|
|
split_k_mode = library::from_string<library::SplitKMode>(
|
|
static_cast<EnumeratedTypeArgument::EnumeratedTypeValue const *>(value_ptr)->element);
|
|
|
|
if (split_k_mode == library::SplitKMode::kInvalid) {
|
|
throw std::runtime_error(
|
|
"arg_as_SplitKModeID() - illegal cast.");
|
|
}
|
|
}
|
|
else {
|
|
|
|
throw std::runtime_error(
|
|
"arg_as_SplitKModeID() - illegal cast.");
|
|
}
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Lexically casts an argument to an int64 if it is defined. Returns true if not null.
|
|
bool arg_as_SplitKModeID(
|
|
library::SplitKMode &split_k_mode,
|
|
char const *name,
|
|
ProblemSpace const &problem_space,
|
|
ProblemSpace::Problem const &problem) {
|
|
|
|
size_t idx = problem_space.argument_index(name);
|
|
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
|
|
|
return arg_as_SplitKModeID(split_k_mode, value_ptr);
|
|
}
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null.
|
|
bool arg_as_scalar(
|
|
std::vector<uint8_t> &bytes,
|
|
library::NumericTypeID numeric_type,
|
|
KernelArgument::Value const *value_ptr) {
|
|
|
|
if (value_ptr->not_null) {
|
|
if (value_ptr->argument->description->type == ArgumentTypeID::kInteger) {
|
|
int64_t int_value = static_cast<IntegerArgument::IntegerValue const *>(value_ptr)->value;
|
|
|
|
// TODO - convert int64_t => destination type
|
|
}
|
|
else if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) {
|
|
std::string const &str_value = static_cast<ScalarArgument::ScalarValue const *>(value_ptr)->value;
|
|
|
|
return lexical_cast(bytes, numeric_type, str_value);
|
|
}
|
|
else {
|
|
throw std::runtime_error(
|
|
"arg_as_int() - illegal cast. Problem space argument must be integer or scalar");
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Lexically casts an argument to a given type and returns a byte array
|
|
bool arg_as_scalar(
|
|
std::vector<uint8_t> &bytes,
|
|
library::NumericTypeID numeric_type,
|
|
char const *name,
|
|
ProblemSpace const &problem_space,
|
|
ProblemSpace::Problem const &problem) {
|
|
|
|
size_t idx = problem_space.argument_index(name);
|
|
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
|
|
|
return arg_as_scalar(bytes, numeric_type, value_ptr);
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Returns true if a tensor description satisfies a `tensor` value
|
|
bool tensor_description_satisfies(
|
|
library::TensorDescription const &tensor_desc,
|
|
TensorArgument::TensorValue const *value_ptr) {
|
|
|
|
if (value_ptr->not_null) {
|
|
if (value_ptr->desc.element != library::NumericTypeID::kUnknown &&
|
|
value_ptr->desc.element != tensor_desc.element) {
|
|
|
|
return false;
|
|
}
|
|
|
|
if (value_ptr->desc.layout != library::LayoutTypeID::kUnknown &&
|
|
value_ptr->desc.layout != tensor_desc.layout) {
|
|
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Returns true if a tensor description satisfies a `tensor` value
|
|
bool tensor_description_satisfies(
|
|
library::TensorDescription const &tensor_desc,
|
|
char const *name,
|
|
ProblemSpace const &problem_space,
|
|
ProblemSpace::Problem const &problem) {
|
|
|
|
size_t idx = problem_space.argument_index(name);
|
|
KernelArgument::Value const *value_ptr = problem.at(idx).get();
|
|
|
|
if (value_ptr->argument->description->type == ArgumentTypeID::kTensor) {
|
|
return tensor_description_satisfies(
|
|
tensor_desc,
|
|
static_cast<TensorArgument::TensorValue const *>(value_ptr));
|
|
}
|
|
else {
|
|
throw std::runtime_error("Kernel argument mismatch");
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
} // namespace profiler
|
|
} // namespace cutlass
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|