Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / circle-opselector / driver / Driver.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ModuleIO.h"
18
19 #include <luci/Profile/CircleNodeID.h>
20
21 #include <arser/arser.h>
22 #include <vconone/vconone.h>
23
24 #include <iostream>
25 #include <string>
26 #include <vector>
27 #include <algorithm>
28 #include <cctype>
29 #include <numeric>
30 #include <sstream>
31
32 void print_version(void)
33 {
34   std::cout << "circle-opselector version " << vconone::get_string() << std::endl;
35   std::cout << vconone::get_copyright() << std::endl;
36 }
37
38 std::vector<std::string> split_into_vector(const std::string &str, const char &delim)
39 {
40   std::vector<std::string> ret;
41   std::istringstream is(str);
42   for (std::string item; std::getline(is, item, delim);)
43   {
44     ret.push_back(item);
45   }
46
47   // remove empty string
48   ret.erase(std::remove_if(ret.begin(), ret.end(), [](const std::string &s) { return s.empty(); }),
49             ret.end());
50
51   return ret;
52 }
53
54 bool is_number(const std::string &s)
55 {
56   return !s.empty() && std::find_if(s.begin(), s.end(),
57                                     [](unsigned char c) { return !std::isdigit(c); }) == s.end();
58 }
59
60 bool is_number(const std::vector<std::string> &vec)
61 {
62   for (const auto &s : vec)
63   {
64     if (not::is_number(s))
65     {
66       return false;
67     }
68   }
69   return true;
70 }
71
72 /**
73  * @brief  Segmentation function for user's '--by_id' input
74  *
75  * @note   This function tokenizes the input data.s
76  *         First, divide it into ',', and if token has '-', devide it once more into '-'.
77  *         For example, if user input is '12,34,56', it is devided into [12,34,56].
78  *         If input is '1-2,34,56', it is devided into [[1,2],34,56].
79  *         And '-' means range so, if input is '2-7', it means all integer between 2-7.
80  */
81 std::vector<uint32_t> split_id_input(const std::string &str)
82 {
83   std::vector<uint32_t> by_id;
84
85   // tokenize colon-separated string
86   auto colon_tokens = ::split_into_vector(str, ',');
87   if (colon_tokens.empty()) // input empty line like "".
88   {
89     std::cerr << "ERROR: Nothing was entered." << std::endl;
90     exit(EXIT_FAILURE);
91   }
92   for (const auto &ctok : colon_tokens)
93   {
94     auto dash_tokens = ::split_into_vector(ctok, '-');
95     if (not::is_number(dash_tokens))
96     {
97       std::cerr << "ERROR: To select operator by id, please use these args: [0-9], '-', ','"
98                 << std::endl;
99       exit(EXIT_FAILURE);
100     }
101     // convert string into integer
102     std::vector<uint32_t> int_tokens;
103     try
104     {
105       std::transform(dash_tokens.begin(), dash_tokens.end(), std::back_inserter(int_tokens),
106                      [](const std::string &str) { return static_cast<uint32_t>(std::stoi(str)); });
107     }
108     catch (const std::out_of_range &)
109     {
110       // if input is big integer like '123467891234', stoi throw this exception.
111       std::cerr << "ERROR: Argument is out of range." << std::endl;
112       exit(EXIT_FAILURE);
113     }
114     catch (...)
115     {
116       std::cerr << "ERROR: Unknown error" << std::endl;
117       exit(EXIT_FAILURE);
118     }
119
120     switch (int_tokens.size())
121     {
122       case 0: // inputs like "-"
123       {
124         std::cerr << "ERROR: Nothing was entered" << std::endl;
125         exit(EXIT_FAILURE);
126       }
127       case 1: // inputs like "1", "2"
128       {
129         by_id.push_back(int_tokens.at(0));
130         break;
131       }
132       case 2: // inputs like "1-2", "11-50"
133       {
134         for (uint32_t i = int_tokens.at(0); i <= int_tokens.at(1); i++)
135         {
136           by_id.push_back(i);
137         }
138         break;
139       }
140       default: // inputs like "1-2-3"
141       {
142         std::cerr << "ERROR: Too many '-' in str." << std::endl;
143         exit(EXIT_FAILURE);
144       }
145     }
146   }
147
148   return by_id;
149 }
150
151 std::vector<std::string> split_name_input(const std::string &str)
152 {
153   return ::split_into_vector(str, ',');
154 }
155
156 int entry(int argc, char **argv)
157 {
158   // TODO Add new option names!
159
160   arser::Arser arser("circle-opselector provides selecting operations in circle model");
161
162   arser.add_argument("--version")
163     .nargs(0)
164     .default_value(false)
165     .help("Show version information and exit")
166     .exit_with(print_version);
167
168   // TODO Add new options!
169
170   arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
171   arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
172
173   // select option
174   arser.add_argument("--by_id")
175     .nargs(1)
176     .type(arser::DataType::STR)
177     .help("Input operation id to select nodes.");
178   arser.add_argument("--by_name")
179     .nargs(1)
180     .type(arser::DataType::STR)
181     .help("Input operation name to select nodes.");
182
183   try
184   {
185     arser.parse(argc, argv);
186   }
187   catch (const std::runtime_error &err)
188   {
189     std::cerr << err.what() << std::endl;
190     std::cout << arser;
191     return EXIT_FAILURE;
192   }
193
194   std::string input_path = arser.get<std::string>("input");
195   std::string output_path = arser.get<std::string>("output");
196
197   std::string operator_input;
198
199   std::vector<uint32_t> by_id;
200   std::vector<std::string> by_name;
201
202   if (!arser["--by_id"] && !arser["--by_name"] || arser["--by_id"] && arser["--by_name"])
203   {
204     std::cerr << "ERROR: Either option '--by_id' or '--by_name' must be specified" << std::endl;
205     std::cerr << arser;
206     return EXIT_FAILURE;
207   }
208
209   if (arser["--by_id"])
210   {
211     operator_input = arser.get<std::string>("--by_id");
212     by_id = split_id_input(operator_input);
213   }
214   if (arser["--by_name"])
215   {
216     operator_input = arser.get<std::string>("--by_name");
217     by_name = split_name_input(operator_input);
218   }
219
220   // Import original circle file.
221   auto module = opselector::getModule(input_path);
222
223   // Select nodes from user input.
224   std::vector<const luci::CircleNode *> selected_nodes;
225
226   // put selected nodes into vector.
227   if (by_id.size())
228   {
229     loco::Graph *graph = module.get()->graph(0); // get main subgraph.
230
231     for (auto node : loco::all_nodes(graph))
232     {
233       auto cnode = loco::must_cast<const luci::CircleNode *>(node);
234
235       try
236       {
237         auto node_id = luci::get_node_id(cnode); // if the node is not operator, throw runtime_error
238
239         for (auto selected_id : by_id)
240           if (selected_id == node_id) // find the selected id
241             selected_nodes.emplace_back(cnode);
242       }
243       catch (std::runtime_error)
244       {
245         continue;
246       }
247     }
248   }
249   if (by_name.size())
250   {
251     loco::Graph *graph = module.get()->graph(0); // get main subgraph.
252
253     for (auto node : loco::all_nodes(graph))
254     {
255       auto cnode = loco::must_cast<const luci::CircleNode *>(node);
256       std::string node_name = cnode->name();
257
258       for (auto selected_name : by_name)
259         if (selected_name.compare(node_name) == 0) // find the selected name
260           selected_nodes.emplace_back(cnode);
261     }
262   }
263   if (selected_nodes.size() == 0)
264   {
265     std::cerr << "ERROR: No operator selected" << std::endl;
266     exit(EXIT_FAILURE);
267   }
268   // TODO implement node selections
269
270   // Export to output Circle file
271   assert(opselector::exportModule(module.get(), output_path));
272
273   return 0;
274 }