cutlass/tools/util/command_line.h

282 lines
8.4 KiB
C
Raw Normal View History

2018-05-17 02:44:56 +08:00
/******************************************************************************
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Utility for parsing command line arguments
*/
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
#include <cuda_runtime.h>
namespace cutlass {
/******************************************************************************
* command_line
******************************************************************************/
/**
* Utility for parsing command line arguments
*/
struct CommandLine {
  /// Keys of every "--<key>" or "--<key>=<value>" argument, in command-line order
  std::vector<std::string> keys;

  /// Values paired index-for-index with `keys`; empty for a bare "--<key>" flag
  std::vector<std::string> values;

  /// Naked arguments (those not beginning with "--"), in command-line order
  std::vector<std::string> args;

  /**
   * Constructor
   *
   * Parses argv[1] .. argv[argc - 1]; argv[0] is skipped. An argument that
   * begins with "--" is split at its first '=' into a key (the text after
   * "--") and a value (the text after '='; empty if there is no '='). Every
   * other argument is stored verbatim in `args`.
   */
  CommandLine(int argc, const char** argv) {
    // keys/values start out empty so that parsed_argc() and
    // check_cmd_line_flag() only ever see arguments that were really passed.
    for (int i = 1; i < argc; ++i) {
      std::string const arg(argv[i]);

      // Anything that does not start with "--" is a naked argument
      if (arg.compare(0, 2, "--") != 0) {
        args.push_back(arg);
        continue;
      }

      std::string::size_type const eq_pos = arg.find('=');
      if (eq_pos == std::string::npos) {
        // Bare flag: "--<key>"
        keys.push_back(arg.substr(2));
        values.push_back(std::string());
      } else {
        // Key/value pair: "--<key>=<value>"
        keys.push_back(arg.substr(2, eq_pos - 2));
        values.push_back(arg.substr(eq_pos + 1));
      }
    }
  }

  /**
   * Checks whether a flag "--<flag>" is present in the commandline
   *
   * Returns true for both bare flags and "--<flag>=<value>" pairs.
   */
  bool check_cmd_line_flag(const char* arg_name) const {
    std::string const name(arg_name);
    for (std::vector<std::string>::const_iterator key = keys.begin(); key != keys.end(); ++key) {
      if (*key == name) {
        return true;
      }
    }
    return false;
  }

  /**
   * Returns number of naked (non-flag and non-key-value) commandline parameters
   *
   * The template parameter is unused and cannot be deduced, so callers must
   * spell it out (e.g. num_naked_args<int>()); it is kept for interface
   * compatibility.
   */
  template <typename value_t>
  int num_naked_args() const {
    return static_cast<int>(args.size());
  }

  /**
   * Returns the commandline parameter for a given index (not including flags)
   *
   * The naked argument at `index` is converted with operator>>. `val` is left
   * untouched when `index` is negative or past the last naked argument.
   */
  template <typename value_t>
  void get_cmd_line_argument(int index, value_t& val) const {
    if (index < 0) {
      return;
    }
    std::vector<std::string>::size_type const pos =
        static_cast<std::vector<std::string>::size_type>(index);
    if (pos < args.size()) {
      std::istringstream str_stream(args[pos]);
      str_stream >> val;
    }
  }

  /**
   * Returns the boolean value specified for a given commandline parameter --<flag>=<bool>
   *
   * `val` is `_default` when the flag is absent. When the flag is present, `val`
   * is false only if its value is exactly "0" or "false"; any other value,
   * including a bare "--<flag>", yields true.
   */
  void get_cmd_line_argument(const char* arg_name, bool& val, bool _default = true) const {
    val = _default;
    if (check_cmd_line_flag(arg_name)) {
      std::string value;
      get_cmd_line_argument(arg_name, value);
      val = !(value == "0" || value == "false");
    }
  }

  /**
   * Returns the value specified for a given commandline parameter --<flag>=<value>
   *
   * `val` is first set to `_default`, then every occurrence of the key is
   * converted into `val` with operator>>, so the last occurrence on the command
   * line wins. Because operator>> is used, a std::string target receives only
   * the first whitespace-delimited word of the value.
   */
  template <typename value_t>
  void get_cmd_line_argument(const char* arg_name,
                             value_t& val,
                             value_t const& _default = value_t()) const {
    val = _default;
    std::string const name(arg_name);
    for (std::vector<std::string>::size_type i = 0; i < keys.size(); ++i) {
      if (keys[i] == name) {
        std::istringstream str_stream(values[i]);
        str_stream >> val;
      }
    }
  }

  /**
   * Returns the values specified for a given commandline parameter --<flag>=<value>,<value>*
   *
   * When the flag is absent `vals` is left as-is, so it may carry defaults.
   * When it is present `vals` is cleared and refilled with the parsed values of
   * every occurrence of the flag, in order.
   */
  template <typename value_t>
  void get_cmd_line_arguments(const char* arg_name,
                              std::vector<value_t>& vals,
                              char sep = ',') const {
    if (!check_cmd_line_flag(arg_name)) {
      return;
    }

    // Clear any default values
    vals.clear();

    // Recover from multi-value string
    std::string const name(arg_name);
    for (std::vector<std::string>::size_type i = 0; i < keys.size(); ++i) {
      if (keys[i] == name) {
        seperate_string(values[i], vals, sep);
      }
    }
  }

  /**
   * Returns the key/value pairs specified for a given commandline parameter
   * --<flag>=<key:value>,<key:value>*
   *
   * Pairs are appended to `tokens`; nothing is appended when the flag is absent.
   */
  void get_cmd_line_argument_pairs(const char* arg_name,
                                   std::vector<std::pair<std::string, std::string> >& tokens,
                                   char delim = ',',
                                   char sep = ':') const {
    if (check_cmd_line_flag(arg_name)) {
      std::string value;
      get_cmd_line_argument(arg_name, value);
      tokenize(tokens, value, delim, sep);
    }
  }

  /**
   * Returns a list of ranges specified for a given commandline parameter
   * --<flag>=<value>,<value_start:value_end>*
   *
   * The value is split on `delim`, and each piece is split again on `sep`;
   * one inner vector per piece is appended to `vals`.
   */
  void get_cmd_line_argument_ranges(const char* arg_name,
                                    std::vector<std::vector<std::string> >& vals,
                                    char delim = ',',
                                    char sep = ':') const {
    std::vector<std::string> ranges;
    get_cmd_line_arguments(arg_name, ranges, delim);

    for (std::vector<std::string>::const_iterator range = ranges.begin(); range != ranges.end();
         ++range) {
      std::vector<std::string> range_vals;
      seperate_string(*range, range_vals, sep);
      vals.push_back(range_vals);
    }
  }

  /**
   * The number of pairs parsed
   *
   * Counts every "--" argument, with or without a value.
   */
  int parsed_argc() const { return static_cast<int>(keys.size()); }

  //-------------------------------------------------------------------------
  // Utility functions
  //-------------------------------------------------------------------------

  /// Tokenizes a `delim`-delimited list of string pairs, each pair split at its
  /// first `sep`. An item without `sep` yields a pair whose second element is
  /// empty. Results are appended to `tokens`.
  static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
                       std::string const& str,
                       char delim = ',',
                       char sep = ':') {
    // Home-built to avoid Boost dependency
    std::string::size_type item_begin = 0;
    while (item_begin < str.size()) {
      std::string::size_type const delim_pos = str.find(delim, item_begin);
      std::string::size_type const item_end =
          (delim_pos != std::string::npos ? delim_pos : str.size());

      // A separator only counts if it lies inside the current item
      std::string::size_type sep_pos = str.find(sep, item_begin);
      std::string::size_type skip = 1;
      if (sep_pos == std::string::npos || sep_pos >= item_end) {
        sep_pos = item_end;
        skip = 0;
      }

      tokens.push_back(std::make_pair(str.substr(item_begin, sep_pos - item_begin),
                                      str.substr(sep_pos + skip, item_end - sep_pos - skip)));
      item_begin = item_end + 1;
    }
  }

  /// Tokenizes a `delim`-delimited list and appends only the first element of
  /// each item (the text before `sep`) to `tokens`.
  static void tokenize(std::vector<std::string>& tokens,
                       std::string const& str,
                       char delim = ',',
                       char sep = ':') {
    typedef std::vector<std::pair<std::string, std::string> > TokenVector;

    TokenVector token_pairs;
    tokenize(token_pairs, str, delim, sep);
    for (TokenVector::const_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) {
      tokens.push_back(tok->first);
    }
  }

  /// Splits `str` on `sep`, converts each non-empty token with operator>> and
  /// appends the successfully converted values to `vals`. Empty tokens (for
  /// example from a trailing separator or an empty string) and tokens that
  /// fail to convert are skipped rather than producing a stale value.
  template <typename value_t>
  static void seperate_string(std::string const& str,
                              std::vector<value_t>& vals,
                              char sep = ',') {
    std::string::size_type token_begin = 0;
    while (token_begin <= str.size()) {
      std::string::size_type token_end = str.find(sep, token_begin);
      if (token_end == std::string::npos) {
        token_end = str.size();
      }

      if (token_end != token_begin) {
        // Each token gets its own stream so that one malformed token cannot
        // desynchronise or poison the parsing of the tokens that follow.
        std::istringstream token_stream(str.substr(token_begin, token_end - token_begin));
        value_t val = value_t();
        if (token_stream >> val) {
          vals.push_back(val);
        }
      }

      // skip over delimiter
      token_begin = token_end + 1;
    }
  }
};
} // namespace cutlass