282 lines
8.4 KiB
C++
282 lines
8.4 KiB
C++
/******************************************************************************
|
|
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are not permitted.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
|
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
|
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
******************************************************************************/
|
|
|
|
#pragma once
|
|
|
|
/**
|
|
* \file
|
|
* Utility for parsing command line arguments
|
|
*/
|
|
|
|
#include <iostream>
|
|
#include <limits>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include <cuda_runtime.h>
|
|
|
|
namespace cutlass {
|
|
|
|
/******************************************************************************
|
|
* command_line
|
|
******************************************************************************/
|
|
|
|
/**
|
|
* Utility for parsing command line arguments
|
|
*/
|
|
struct CommandLine {
  /// Option names taken from "--key" / "--key=value" arguments, in command-line order.
  std::vector<std::string> keys;

  /// Option values, paired index-for-index with `keys` ("" when no '=' was given).
  std::vector<std::string> values;

  /// Naked (positional) arguments: every argument that does not start with "--".
  std::vector<std::string> args;

  /**
   * Constructor
   *
   * Parses argv[1 .. argc-1] (argv[0], the program name, is skipped).
   * An argument beginning with "--" is split at its first '=' into a key and a
   * value ("--key" alone yields an empty value); any other argument is appended
   * to `args`. Only options actually present on the command line are recorded.
   */
  CommandLine(int argc, const char** argv) {
    for (int i = 1; i < argc; ++i) {
      std::string const arg(argv[i]);

      // compare() on a prefix also handles "", "-" and "-x" without reading past the end.
      if (arg.compare(0, 2, "--") != 0) {
        args.push_back(arg);
        continue;
      }

      std::string::size_type const eq = arg.find('=');
      if (eq == std::string::npos) {
        keys.push_back(arg.substr(2));
        values.push_back(std::string());
      } else {
        // eq >= 2 here because the argument starts with "--".
        keys.push_back(arg.substr(2, eq - 2));
        values.push_back(arg.substr(eq + 1));
      }
    }
  }

  /**
   * Checks whether a flag "--<flag>" (with or without "=<value>") is present
   * in the commandline.
   */
  bool check_cmd_line_flag(const char* arg_name) const {
    for (std::size_t i = 0; i < keys.size(); ++i) {
      if (keys[i] == arg_name) return true;
    }
    return false;
  }

  /**
   * Returns number of naked (non-flag and non-key-value) commandline parameters
   */
  int num_naked_args() const { return static_cast<int>(args.size()); }

  /**
   * Returns number of naked (non-flag and non-key-value) commandline parameters.
   *
   * The template parameter is unused; this overload only exists so that
   * existing callers written as num_naked_args<T>() keep compiling.
   */
  template <typename value_t>
  int num_naked_args() const {
    return num_naked_args();
  }

  /**
   * Returns the commandline parameter for a given index (not including flags).
   *
   * `val` is left unchanged when `index` is out of range or when the naked
   * argument cannot be extracted as a value_t with operator>>.
   */
  template <typename value_t>
  void get_cmd_line_argument(int index, value_t& val) const {
    if (index >= 0 && static_cast<std::size_t>(index) < args.size()) {
      parse_value(args[index], val);
    }
  }

  /**
   * Returns the boolean value specified for a given commandline parameter --<flag>=<bool>
   *
   * `val` becomes `_default` when the flag is absent. When the flag is present,
   * `val` is false only for the values "0" and "false"; anything else
   * (including a bare "--<flag>") yields true.
   */
  void get_cmd_line_argument(const char* arg_name, bool& val, bool _default = true) const {
    val = _default;
    if (check_cmd_line_flag(arg_name)) {
      std::string value;
      get_cmd_line_argument(arg_name, value);

      val = !(value == "0" || value == "false");
    }
  }

  /**
   * Returns the value specified for a given commandline parameter --<flag>=<value>
   *
   * `val` starts as `_default`. Each occurrence of the flag is extracted with
   * operator>>, so a later occurrence overrides an earlier one. A value that
   * fails to parse (e.g. "--n=abc" for an int) is ignored instead of
   * clobbering the default.
   */
  template <typename value_t>
  void get_cmd_line_argument(const char* arg_name,
                             value_t& val,
                             value_t const& _default = value_t()) const {
    val = _default;

    for (std::size_t i = 0; i < keys.size(); ++i) {
      if (keys[i] == arg_name) {
        parse_value(values[i], val);
      }
    }
  }

  /**
   * Returns the values specified for a given commandline parameter --<flag>=<value>,<value>*
   *
   * If the flag is absent, `vals` is left untouched, so callers may pre-load
   * defaults. Otherwise `vals` is cleared and refilled from every occurrence
   * of the flag. Empty or unparseable fields are skipped.
   */
  template <typename value_t>
  void get_cmd_line_arguments(const char* arg_name,
                              std::vector<value_t>& vals,
                              char sep = ',') const {
    if (!check_cmd_line_flag(arg_name)) {
      return;
    }

    // Clear any default values
    vals.clear();

    // Recover from multi-value string
    for (std::size_t i = 0; i < keys.size(); ++i) {
      if (keys[i] == arg_name) {
        seperate_string(values[i], vals, sep);
      }
    }
  }

  /**
   * Returns the key/value pairs specified for a given commandline parameter
   * --<flag>=<key:value>,<key:value>*
   *
   * A field without `sep` yields the pair (field, ""). Pairs are appended to
   * `tokens`; nothing happens if the flag is absent.
   */
  void get_cmd_line_argument_pairs(const char* arg_name,
                                   std::vector<std::pair<std::string, std::string> >& tokens,
                                   char delim = ',',
                                   char sep = ':') const {
    if (check_cmd_line_flag(arg_name)) {
      std::string value;
      get_cmd_line_argument(arg_name, value);

      tokenize(tokens, value, delim, sep);
    }
  }

  /**
   * Returns a list of ranges specified for a given commandline parameter
   * --<flag>=<value>,<value_start:value_end[:...]>*
   *
   * Each `delim`-separated field is split on `sep` into its own vector of
   * strings, which is appended to `vals`.
   */
  void get_cmd_line_argument_ranges(const char* arg_name,
                                    std::vector<std::vector<std::string> >& vals,
                                    char delim = ',',
                                    char sep = ':') const {
    std::vector<std::string> ranges;
    get_cmd_line_arguments(arg_name, ranges, delim);

    for (std::size_t r = 0; r < ranges.size(); ++r) {
      std::vector<std::string> range_vals;
      seperate_string(ranges[r], range_vals, sep);
      vals.push_back(range_vals);
    }
  }

  /**
   * The number of "--key[=value]" pairs parsed
   */
  int parsed_argc() const { return static_cast<int>(keys.size()); }

  //-------------------------------------------------------------------------
  // Utility functions
  //-------------------------------------------------------------------------

  /// Extracts a value_t from `text` with operator>>. On success stores it in
  /// `val` and returns true; on failure returns false and leaves `val` as is.
  template <typename value_t>
  static bool parse_value(std::string const& text, value_t& val) {
    std::istringstream str_stream(text);
    value_t parsed(val);
    if (str_stream >> parsed) {
      val = parsed;
      return true;
    }
    return false;
  }

  /// Tokenizes a `delim`-delimited list of string pairs delimited by `sep`
  /// (default: comma-delimited "key:value" pairs). A token without `sep`
  /// becomes (token, ""). Tokens are appended to `tokens`.
  static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
                       std::string const& str,
                       char delim = ',',
                       char sep = ':') {
    // Home-built to avoid Boost dependency
    std::string::size_type start = 0;

    while (start < str.size()) {
      std::string::size_type end = str.find(delim, start);
      if (end == std::string::npos) {
        end = str.size();
      }

      // Search for the separator only inside the current token.
      std::string const token = str.substr(start, end - start);
      std::string::size_type const split = token.find(sep);

      if (split == std::string::npos) {
        tokens.push_back(std::pair<std::string, std::string>(token, std::string()));
      } else {
        tokens.push_back(std::pair<std::string, std::string>(token.substr(0, split),
                                                             token.substr(split + 1)));
      }

      start = end + 1;
    }
  }

  /// Tokenizes a `delim`-delimited list of `sep`-delimited pairs and appends
  /// only the first element (the key) of each pair to `tokens`.
  static void tokenize(std::vector<std::string>& tokens,
                       std::string const& str,
                       char delim = ',',
                       char sep = ':') {
    std::vector<std::pair<std::string, std::string> > token_pairs;
    tokenize(token_pairs, str, delim, sep);
    for (std::size_t i = 0; i < token_pairs.size(); ++i) {
      tokens.push_back(token_pairs[i].first);
    }
  }

  /// Splits `str` on `sep` and appends each field, extracted as a value_t with
  /// operator>>, to `vals`. Empty fields (including those produced by leading
  /// or trailing separators) and fields that fail to parse are skipped.
  /// (The historical spelling of the name is kept for existing callers.)
  template <typename value_t>
  static void seperate_string(std::string const& str,
                              std::vector<value_t>& vals,
                              char sep = ',') {
    std::string::size_type begin = 0;

    while (begin < str.size()) {
      std::string::size_type end = str.find(sep, begin);
      if (end == std::string::npos) {
        end = str.size();
      }

      if (end > begin) {
        // Value-initialize so a failed parse never exposes an indeterminate value.
        value_t val = value_t();
        if (parse_value(str.substr(begin, end - begin), val)) {
          vals.push_back(val);
        }
      }

      begin = end + 1;
    }
  }
};
|
|
|
|
} // namespace cutlass
|