2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
19 #include <luci/Profile/CircleNodeID.h>
21 #include <arser/arser.h>
22 #include <vconone/vconone.h>
32 void print_version(void)
34 std::cout << "circle-opselector version " << vconone::get_string() << std::endl;
35 std::cout << vconone::get_copyright() << std::endl;
/**
 * @brief Split @p str at every occurrence of @p delim
 *
 * @return the non-empty pieces in their original order; empty pieces (produced by
 *         leading, trailing or consecutive delimiters) are dropped
 */
std::vector<std::string> split_into_vector(const std::string &str, const char &delim)
{
  std::vector<std::string> pieces;
  std::istringstream stream(str);
  std::string piece;
  while (std::getline(stream, piece, delim))
  {
    // keep only non-empty pieces, e.g. "a,,b" yields "a" and "b"
    if (!piece.empty())
      pieces.push_back(piece);
  }
  return pieces;
}
/**
 * @brief Check whether @p s is a non-empty string consisting only of decimal digits
 *
 * @note Signs, spaces and other characters make the result false.
 */
bool is_number(const std::string &s)
{
  if (s.empty())
    return false;

  // cast through unsigned char so std::isdigit never receives a negative value
  return std::all_of(s.begin(), s.end(), [](unsigned char ch) { return std::isdigit(ch) != 0; });
}
60 bool is_number(const std::vector<std::string> &vec)
62 for (const auto &s : vec)
64 if (not::is_number(s))
73 * @brief Segmentation function for user's '--by_id' input
75 * @note This function tokenizes the input data.s
76 * First, divide it into ',', and if token has '-', devide it once more into '-'.
77 * For example, if user input is '12,34,56', it is devided into [12,34,56].
78 * If input is '1-2,34,56', it is devided into [[1,2],34,56].
79 * And '-' means range so, if input is '2-7', it means all integer between 2-7.
81 std::vector<uint32_t> split_id_input(const std::string &str)
83 std::vector<uint32_t> by_id;
85 // tokenize colon-separated string
86 auto colon_tokens = ::split_into_vector(str, ',');
87 if (colon_tokens.empty()) // input empty line like "".
89 std::cerr << "ERROR: Nothing was entered." << std::endl;
92 for (const auto &ctok : colon_tokens)
94 auto dash_tokens = ::split_into_vector(ctok, '-');
95 if (not::is_number(dash_tokens))
97 std::cerr << "ERROR: To select operator by id, please use these args: [0-9], '-', ','"
101 // convert string into integer
102 std::vector<uint32_t> int_tokens;
105 std::transform(dash_tokens.begin(), dash_tokens.end(), std::back_inserter(int_tokens),
106 [](const std::string &str) { return static_cast<uint32_t>(std::stoi(str)); });
108 catch (const std::out_of_range &)
110 // if input is big integer like '123467891234', stoi throw this exception.
111 std::cerr << "ERROR: Argument is out of range." << std::endl;
116 std::cerr << "ERROR: Unknown error" << std::endl;
120 switch (int_tokens.size())
122 case 0: // inputs like "-"
124 std::cerr << "ERROR: Nothing was entered" << std::endl;
127 case 1: // inputs like "1", "2"
129 by_id.push_back(int_tokens.at(0));
132 case 2: // inputs like "1-2", "11-50"
134 for (uint32_t i = int_tokens.at(0); i <= int_tokens.at(1); i++)
140 default: // inputs like "1-2-3"
142 std::cerr << "ERROR: Too many '-' in str." << std::endl;
151 std::vector<std::string> split_name_input(const std::string &str)
153 return ::split_into_vector(str, ',');
156 int entry(int argc, char **argv)
158 // TODO Add new option names!
160 arser::Arser arser("circle-opselector provides selecting operations in circle model");
162 arser.add_argument("--version")
164 .default_value(false)
165 .help("Show version information and exit")
166 .exit_with(print_version);
168 // TODO Add new options!
170 arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
171 arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
174 arser.add_argument("--by_id")
176 .type(arser::DataType::STR)
177 .help("Input operation id to select nodes.");
178 arser.add_argument("--by_name")
180 .type(arser::DataType::STR)
181 .help("Input operation name to select nodes.");
185 arser.parse(argc, argv);
187 catch (const std::runtime_error &err)
189 std::cerr << err.what() << std::endl;
194 std::string input_path = arser.get<std::string>("input");
195 std::string output_path = arser.get<std::string>("output");
197 std::string operator_input;
199 std::vector<uint32_t> by_id;
200 std::vector<std::string> by_name;
202 if (!arser["--by_id"] && !arser["--by_name"] || arser["--by_id"] && arser["--by_name"])
204 std::cerr << "ERROR: Either option '--by_id' or '--by_name' must be specified" << std::endl;
209 if (arser["--by_id"])
211 operator_input = arser.get<std::string>("--by_id");
212 by_id = split_id_input(operator_input);
214 if (arser["--by_name"])
216 operator_input = arser.get<std::string>("--by_name");
217 by_name = split_name_input(operator_input);
220 // Import original circle file.
221 auto module = opselector::getModule(input_path);
223 // Select nodes from user input.
224 std::vector<const luci::CircleNode *> selected_nodes;
226 // put selected nodes into vector.
229 loco::Graph *graph = module.get()->graph(0); // get main subgraph.
231 for (auto node : loco::all_nodes(graph))
233 auto cnode = loco::must_cast<const luci::CircleNode *>(node);
237 auto node_id = luci::get_node_id(cnode); // if the node is not operator, throw runtime_error
239 for (auto selected_id : by_id)
240 if (selected_id == node_id) // find the selected id
241 selected_nodes.emplace_back(cnode);
243 catch (std::runtime_error)
251 loco::Graph *graph = module.get()->graph(0); // get main subgraph.
253 for (auto node : loco::all_nodes(graph))
255 auto cnode = loco::must_cast<const luci::CircleNode *>(node);
256 std::string node_name = cnode->name();
258 for (auto selected_name : by_name)
259 if (selected_name.compare(node_name) == 0) // find the selected name
260 selected_nodes.emplace_back(cnode);
263 if (selected_nodes.size() == 0)
265 std::cerr << "ERROR: No operator selected" << std::endl;
268 // TODO implement node selections
270 // Export to output Circle file
271 assert(opselector::exportModule(module.get(), output_path));