Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-opselector / src / OpSelector.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "OpSelector.h"
18
19 #include <luci/ConnectNode.h>
20 #include <luci/Profile/CircleNodeID.h>
21 #include <luci/Service/CircleNodeClone.h>
22
23 #include <algorithm>
24 #include <cassert>
25 #include <sstream>
26 #include <string>
27 #include <vector>
28
29 namespace
30 {
31
32 /**
33  * @brief Tokenize given string
34  *
35  * Assumes given string looks like below.
36  *
37  * - '1,2,5,7,9'
38  * - '1-5,6,7,9,12-14'
39  * - 'tensor_a,tensor_b,tensor_d'
40  *
41  * NOTE. 1-5 is same with '1,2,3,4,5'.
42  *
43  * WARNING. SelectType::NAME doesn't allow '-' like 'tensor_a-tensor_c'.
44  */
45 std::vector<std::string> split_into_vector(const std::string &str, const char &delim)
46 {
47   std::vector<std::string> ret;
48   std::istringstream is(str);
49   for (std::string item; std::getline(is, item, delim);)
50   {
51     ret.push_back(item);
52   }
53
54   // Remove empty string
55   ret.erase(std::remove_if(ret.begin(), ret.end(), [](const std::string &s) { return s.empty(); }),
56             ret.end());
57
58   return ret;
59 }
60
61 bool is_number(const std::string &s)
62 {
63   return !s.empty() && std::find_if(s.begin(), s.end(),
64                                     [](unsigned char c) { return !std::isdigit(c); }) == s.end();
65 }
66
67 bool is_number(const std::vector<std::string> &vec)
68 {
69   for (const auto &s : vec)
70   {
71     if (not::is_number(s))
72     {
73       return false;
74     }
75   }
76   return true;
77 }
78
79 // TODO Move this class into a separate header for reuse
80 class IsMultiOutputNode final : public luci::CircleNodeVisitor<bool>
81 {
82 public:
83   bool visit(const luci::CircleCustom *) final { return true; }
84   bool visit(const luci::CircleIf *) final { return true; }
85   bool visit(const luci::CircleNonMaxSuppressionV4 *) final { return true; }
86   bool visit(const luci::CircleNonMaxSuppressionV5 *) final { return true; }
87   bool visit(const luci::CircleSplit *) final { return true; }
88   bool visit(const luci::CircleSplitV *) final { return true; }
89   bool visit(const luci::CircleTopKV2 *) final { return true; }
90   bool visit(const luci::CircleUnique *) final { return true; }
91   bool visit(const luci::CircleUnpack *) final { return true; }
92   bool visit(const luci::CircleWhile *) final { return true; }
93   // default is false
94   bool visit(const luci::CircleNode *) final { return false; }
95 };
96
97 std::unique_ptr<loco::Graph> make_graph(const std::vector<const luci::CircleNode *> nodes)
98 {
99   auto graph = loco::make_graph();
100
101   luci::CloneContext ctx;
102   // clone nodes
103   for (const auto &n : nodes)
104   {
105     auto clone = luci::clone_node(n, graph.get());
106     ctx.emplace(n, clone);
107   }
108   // set graph input
109   for (const auto &n : nodes)
110   {
111     for (uint32_t i = 0; i < n->arity(); i++)
112     {
113       auto arg = n->arg(i);
114       auto input_node = dynamic_cast<luci::CircleNode *>(arg);
115       auto ctx_it = ctx.find(input_node);
116       // check if the node already has been cloned
117       if (ctx_it != ctx.end())
118         continue;
119       // the node isn't graph input if it is an other node's input
120       if (std::find(nodes.begin(), nodes.end(), arg) != nodes.end())
121         continue;
122       auto circle_const = dynamic_cast<luci::CircleConst *>(arg);
123       if (circle_const != nullptr)
124       {
125         auto clone = luci::clone_node(circle_const, graph.get());
126         ctx.emplace(circle_const, clone);
127       }
128       else
129       {
130         // circle input
131         auto circle_input = graph->nodes()->create<luci::CircleInput>();
132         input_node = dynamic_cast<luci::CircleNode *>(arg);
133         if (not input_node)
134         {
135           throw std::runtime_error{"ERROR: Invalid graph"};
136         }
137         luci::copy_common_attributes(input_node, circle_input);
138         ctx.emplace(input_node, circle_input);
139         // graph input
140         auto graph_input = graph->inputs()->create();
141         graph_input->name(circle_input->name());
142         graph_input->dtype(circle_input->dtype());
143         // graph input shape
144         auto input_shape = std::make_unique<loco::TensorShape>();
145         input_shape->rank(circle_input->rank());
146         for (uint32_t i = 0; i < circle_input->rank(); i++)
147         {
148           if (circle_input->dim(i).known())
149           {
150             circle_input->dim(i).set(circle_input->dim(i).value());
151           }
152         }
153         graph_input->shape(std::move(input_shape));
154
155         circle_input->index(graph_input->index());
156       }
157     }
158   }
159
160   const auto original_graph = nodes.at(0)->graph();
161   const auto original_outputs = loco::output_nodes(const_cast<loco::Graph *>(original_graph));
162
163   // set graph output
164   for (auto &n : nodes)
165   {
166     auto outputs = loco::succs(n);
167     bool beingUsed = false;
168     for (const auto &o : outputs)
169     {
170       if (std::find(nodes.begin(), nodes.end(), o) != nodes.end())
171       {
172         beingUsed = true;
173         break;
174       }
175     }
176
177     bool originalOutput = false;
178     for (const auto &o : outputs)
179     {
180       if (std::find(original_outputs.begin(), original_outputs.end(), o) != original_outputs.end())
181       {
182         originalOutput = true;
183         break;
184       }
185     }
186
187     // the node isn't graph output if it is an other node's output
188     if (beingUsed and not originalOutput)
189       continue;
190
191     IsMultiOutputNode multiout_visitor;
192     bool isMultiOut = n->accept(&multiout_visitor);
193     for (auto &o : outputs)
194     {
195       const luci::CircleNode *output_node = nullptr;
196       if (isMultiOut)
197       {
198         output_node = dynamic_cast<const luci::CircleNode *>(o);
199         if (not output_node)
200         {
201           throw std::runtime_error{"ERROR: Invalid graph"};
202         }
203       }
204       else
205       {
206         output_node = n;
207       }
208       // circle output
209       auto circle_output = graph->nodes()->create<luci::CircleOutput>();
210       luci::copy_common_attributes(output_node, circle_output);
211       // connect to cloned output node
212       circle_output->from(ctx.find(output_node)->second);
213       // graph output
214       auto graph_output = graph->outputs()->create();
215       graph_output->name(output_node->name());
216       graph_output->dtype(output_node->dtype());
217       // graph output shape
218       auto output_shape = std::make_unique<loco::TensorShape>();
219       output_shape->rank(circle_output->rank());
220       for (uint32_t i = 0; i < output_shape->rank(); i++)
221       {
222         if (circle_output->dim(i).known())
223         {
224           output_shape->dim(i).set(circle_output->dim(i).value());
225         }
226       }
227       graph_output->shape(std::move(output_shape));
228
229       circle_output->index(graph_output->index());
230       if (not isMultiOut)
231         break;
232     }
233   }
234   // connect nodes
235   for (const auto &n : nodes)
236   {
237     luci::clone_connect(n, ctx);
238   }
239
240   return graph;
241 }
242
243 } // namespace
244
245 namespace opselector
246 {
247
248 OpSelector::OpSelector(const luci::Module *module) : _module{module}
249 {
250   if (_module->size() != 1)
251   {
252     throw std::runtime_error{"ERROR: Not support two or more subgraphs"};
253   }
254 }
255
256 template <>
257 std::vector<const luci::CircleNode *>
258 OpSelector::select_by<SelectType::ID>(const std::vector<std::string> &comma_tokens)
259 {
260   std::vector<uint32_t> by_id;
261
262   for (const auto &comma_token : comma_tokens)
263   {
264     auto dash_tokens = ::split_into_vector(comma_token, '-');
265     if (not::is_number(dash_tokens))
266     {
267       throw std::runtime_error{
268         "ERROR: To select operator by id, please use these args: [0-9], '-', ','"};
269     }
270
271     // Convert string into integer
272     std::vector<uint32_t> int_tokens;
273     try
274     {
275       std::transform(dash_tokens.begin(), dash_tokens.end(), std::back_inserter(int_tokens),
276                      [](const std::string &str) { return static_cast<uint32_t>(std::stoi(str)); });
277     }
278     catch (const std::out_of_range &)
279     {
280       // Uf input is big integer like '123467891234', stoi throws this exception.
281       throw std::runtime_error{"ERROR: Argument is out of range."};
282     }
283     catch (...)
284     {
285       throw std::runtime_error{"ERROR: Unknown error"};
286     }
287
288     switch (int_tokens.size())
289     {
290       case 0: // inputs like "-"
291       {
292         throw std::runtime_error{"ERROR: Nothing was entered"};
293       }
294       case 1: // inputs like "1", "2"
295       {
296         by_id.push_back(int_tokens.at(0));
297         break;
298       }
299       case 2: // inputs like "1-2", "11-50"
300       {
301         for (uint32_t i = int_tokens.at(0); i <= int_tokens.at(1); i++)
302         {
303           by_id.push_back(i);
304         }
305         break;
306       }
307       default: // inputs like "1-2-3"
308       {
309         throw std::runtime_error{"ERROR: Too many '-' in str."};
310       }
311     }
312   }
313
314   loco::Graph *graph = _module->graph(0);
315   std::vector<const luci::CircleNode *> selected_nodes;
316
317   for (auto node : loco::all_nodes(graph))
318   {
319     auto cnode = loco::must_cast<const luci::CircleNode *>(node);
320
321     try
322     {
323       auto node_id = luci::get_node_id(cnode);
324       for (auto selected_id : by_id)
325       {
326         if (selected_id == node_id)
327         {
328           selected_nodes.emplace_back(cnode);
329         }
330       }
331     }
332     catch (const std::runtime_error &)
333     {
334       continue;
335     }
336   }
337
338   return selected_nodes;
339 }
340
341 template <>
342 std::vector<const luci::CircleNode *>
343 OpSelector::select_by<SelectType::NAME>(const std::vector<std::string> &tokens)
344 {
345   loco::Graph *graph = _module->graph(0);
346   std::vector<const luci::CircleNode *> selected_nodes;
347
348   for (auto node : loco::all_nodes(graph))
349   {
350     auto cnode = loco::must_cast<const luci::CircleNode *>(node);
351     std::string node_name = cnode->name();
352
353     for (const auto &selected_name : tokens)
354       if (selected_name.compare(node_name) == 0) // find the selected name
355         selected_nodes.emplace_back(cnode);
356   }
357
358   return selected_nodes;
359 }
360
361 template <SelectType SELECT_TYPE>
362 std::unique_ptr<luci::Module> OpSelector::select_by(const std::string &str)
363 {
364   auto colon_tokens = ::split_into_vector(str, ',');
365   if (colon_tokens.empty())
366   {
367     throw std::runtime_error{"ERROR: Nothing was entered."};
368   }
369
370   assert(_module->size() == 1);
371
372   auto selected_nodes = select_by<SELECT_TYPE>(colon_tokens);
373
374   // multiout node should be considered
375   IsMultiOutputNode multiout_visitor;
376   std::vector<const luci::CircleNode *> output_nodes;
377   for (const auto &node : selected_nodes)
378   {
379     if (node->accept(&multiout_visitor))
380     {
381       auto outputs = loco::succs(node);
382       for (auto &o : outputs)
383       {
384         output_nodes.push_back(dynamic_cast<luci::CircleNode *>(o));
385       }
386     }
387   }
388   selected_nodes.insert(selected_nodes.end(), output_nodes.begin(), output_nodes.end());
389
390   auto new_module = std::make_unique<luci::Module>();
391   new_module->add(::make_graph(selected_nodes));
392
393   return new_module;
394 }
395
396 template std::unique_ptr<luci::Module>
397 OpSelector::select_by<SelectType::ID>(const std::string &str);
398
399 template std::unique_ptr<luci::Module>
400 OpSelector::select_by<SelectType::NAME>(const std::string &str);
401
402 } // namespace opselector