31cbf9c0cad9040c874f00c2b1f3d96bb5de1bdb
[platform/core/ml/nnfw.git] / compiler / circle-opselector / src / OpSelector.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "OpSelector.h"
18
19 #include <luci/ConnectNode.h>
20 #include <luci/Profile/CircleNodeID.h>
21 #include <luci/Service/CircleNodeClone.h>
22
23 #include <algorithm>
24 #include <cassert>
25 #include <sstream>
26 #include <string>
27 #include <vector>
28
29 namespace
30 {
31
32 /**
33  * @brief Tokenize given string
34  *
35  * Assumes given string looks like below.
36  *
37  * - '1,2,5,7,9'
38  * - '1-5,6,7,9,12-14'
39  * - 'tensor_a,tensor_b,tensor_d'
40  *
41  * NOTE. 1-5 is same with '1,2,3,4,5'.
42  *
43  * WARNING. SelectType::NAME doesn't allow '-' like 'tensor_a-tensor_c'.
44  */
45 std::vector<std::string> split_into_vector(const std::string &str, const char &delim)
46 {
47   std::vector<std::string> ret;
48   std::istringstream is(str);
49   for (std::string item; std::getline(is, item, delim);)
50   {
51     ret.push_back(item);
52   }
53
54   // Remove empty string
55   ret.erase(std::remove_if(ret.begin(), ret.end(), [](const std::string &s) { return s.empty(); }),
56             ret.end());
57
58   return ret;
59 }
60
61 bool is_number(const std::string &s)
62 {
63   return !s.empty() && std::find_if(s.begin(), s.end(),
64                                     [](unsigned char c) { return !std::isdigit(c); }) == s.end();
65 }
66
67 bool is_number(const std::vector<std::string> &vec)
68 {
69   for (const auto &s : vec)
70   {
71     if (not::is_number(s))
72     {
73       return false;
74     }
75   }
76   return true;
77 }
78
79 // TODO Move this class into a separate header for reuse
80 class IsMultiOutputNode final : public luci::CircleNodeVisitor<bool>
81 {
82 public:
83   bool visit(const luci::CircleCustom *) final { return true; }
84   bool visit(const luci::CircleIf *) final { return true; }
85   bool visit(const luci::CircleNonMaxSuppressionV4 *) final { return true; }
86   bool visit(const luci::CircleNonMaxSuppressionV5 *) final { return true; }
87   bool visit(const luci::CircleSplit *) final { return true; }
88   bool visit(const luci::CircleSplitV *) final { return true; }
89   bool visit(const luci::CircleTopKV2 *) final { return true; }
90   bool visit(const luci::CircleUnique *) final { return true; }
91   bool visit(const luci::CircleUnpack *) final { return true; }
92   bool visit(const luci::CircleWhile *) final { return true; }
93   // default is false
94   bool visit(const luci::CircleNode *) final { return false; }
95 };
96
97 std::unique_ptr<loco::Graph> make_graph(const std::vector<const luci::CircleNode *> nodes)
98 {
99   auto graph = loco::make_graph();
100
101   luci::CloneContext ctx;
102   // clone nodes
103   for (const auto &n : nodes)
104   {
105     auto clone = luci::clone_node(n, graph.get());
106     ctx.emplace(n, clone);
107   }
108   // set graph input
109   for (const auto &n : nodes)
110   {
111     for (uint32_t i = 0; i < n->arity(); i++)
112     {
113       auto arg = n->arg(i);
114       auto input_node = dynamic_cast<luci::CircleNode *>(arg);
115       auto ctx_it = ctx.find(input_node);
116       // check if the node already has been cloned
117       if (ctx_it != ctx.end())
118         continue;
119       // the node isn't graph input if it is an other node's input
120       if (std::find(nodes.begin(), nodes.end(), arg) != nodes.end())
121         continue;
122       auto circle_const = dynamic_cast<luci::CircleConst *>(arg);
123       if (circle_const != nullptr)
124       {
125         auto clone = luci::clone_node(circle_const, graph.get());
126         ctx.emplace(circle_const, clone);
127       }
128       else
129       {
130         // circle input
131         auto circle_input = graph->nodes()->create<luci::CircleInput>();
132         input_node = dynamic_cast<luci::CircleNode *>(arg);
133         if (not input_node)
134         {
135           throw std::runtime_error{"ERROR: Invalid graph"};
136         }
137         luci::copy_common_attributes(input_node, circle_input);
138         ctx.emplace(input_node, circle_input);
139         // graph input
140         auto graph_input = graph->inputs()->create();
141         graph_input->name(circle_input->name());
142         graph_input->dtype(circle_input->dtype());
143         // graph input shape
144         auto input_shape = std::make_unique<loco::TensorShape>();
145         input_shape->rank(circle_input->rank());
146         for (uint32_t i = 0; i < circle_input->rank(); i++)
147         {
148           if (circle_input->dim(i).known())
149           {
150             circle_input->dim(i).set(circle_input->dim(i).value());
151           }
152         }
153         graph_input->shape(std::move(input_shape));
154
155         circle_input->index(graph_input->index());
156       }
157     }
158   }
159   // set graph output
160   for (auto &n : nodes)
161   {
162     auto outputs = loco::succs(n);
163     bool beingUsed = false;
164     for (const auto &o : outputs)
165     {
166       if (std::find(nodes.begin(), nodes.end(), o) != nodes.end())
167       {
168         beingUsed = true;
169         break;
170       }
171     }
172     // the node isn't graph output if it is an other node's output
173     if (beingUsed)
174       continue;
175
176     IsMultiOutputNode multiout_visitor;
177     bool isMultiOut = n->accept(&multiout_visitor);
178     for (auto &o : outputs)
179     {
180       const luci::CircleNode *output_node = nullptr;
181       if (isMultiOut)
182       {
183         output_node = dynamic_cast<const luci::CircleNode *>(o);
184         if (not output_node)
185         {
186           throw std::runtime_error{"ERROR: Invalid graph"};
187         }
188       }
189       else
190       {
191         output_node = n;
192       }
193       // circle output
194       auto circle_output = graph->nodes()->create<luci::CircleOutput>();
195       luci::copy_common_attributes(output_node, circle_output);
196       // connect to cloned output node
197       circle_output->from(ctx.find(output_node)->second);
198       // graph output
199       auto graph_output = graph->outputs()->create();
200       graph_output->name(output_node->name());
201       graph_output->dtype(output_node->dtype());
202       // graph output shape
203       auto output_shape = std::make_unique<loco::TensorShape>();
204       output_shape->rank(circle_output->rank());
205       for (uint32_t i = 0; i < output_shape->rank(); i++)
206       {
207         if (circle_output->dim(i).known())
208         {
209           output_shape->dim(i).set(circle_output->dim(i).value());
210         }
211       }
212       graph_output->shape(std::move(output_shape));
213
214       circle_output->index(graph_output->index());
215       if (not isMultiOut)
216         break;
217     }
218   }
219   // connect nodes
220   for (const auto &n : nodes)
221   {
222     luci::clone_connect(n, ctx);
223   }
224
225   return graph;
226 }
227
228 } // namespace
229
230 namespace opselector
231 {
232
233 OpSelector::OpSelector(const luci::Module *module) : _module{module}
234 {
235   if (_module->size() != 1)
236   {
237     throw std::runtime_error{"ERROR: Not support two or more subgraphs"};
238   }
239 }
240
241 template <>
242 std::vector<const luci::CircleNode *>
243 OpSelector::select_by<SelectType::ID>(const std::vector<std::string> &comma_tokens)
244 {
245   std::vector<uint32_t> by_id;
246
247   for (const auto &comma_token : comma_tokens)
248   {
249     auto dash_tokens = ::split_into_vector(comma_token, '-');
250     if (not::is_number(dash_tokens))
251     {
252       throw std::runtime_error{
253         "ERROR: To select operator by id, please use these args: [0-9], '-', ','"};
254     }
255
256     // Convert string into integer
257     std::vector<uint32_t> int_tokens;
258     try
259     {
260       std::transform(dash_tokens.begin(), dash_tokens.end(), std::back_inserter(int_tokens),
261                      [](const std::string &str) { return static_cast<uint32_t>(std::stoi(str)); });
262     }
263     catch (const std::out_of_range &)
264     {
265       // Uf input is big integer like '123467891234', stoi throws this exception.
266       throw std::runtime_error{"ERROR: Argument is out of range."};
267     }
268     catch (...)
269     {
270       throw std::runtime_error{"ERROR: Unknown error"};
271     }
272
273     switch (int_tokens.size())
274     {
275       case 0: // inputs like "-"
276       {
277         throw std::runtime_error{"ERROR: Nothing was entered"};
278       }
279       case 1: // inputs like "1", "2"
280       {
281         by_id.push_back(int_tokens.at(0));
282         break;
283       }
284       case 2: // inputs like "1-2", "11-50"
285       {
286         for (uint32_t i = int_tokens.at(0); i <= int_tokens.at(1); i++)
287         {
288           by_id.push_back(i);
289         }
290         break;
291       }
292       default: // inputs like "1-2-3"
293       {
294         throw std::runtime_error{"ERROR: Too many '-' in str."};
295       }
296     }
297   }
298
299   loco::Graph *graph = _module->graph(0);
300   std::vector<const luci::CircleNode *> selected_nodes;
301
302   for (auto node : loco::all_nodes(graph))
303   {
304     auto cnode = loco::must_cast<const luci::CircleNode *>(node);
305
306     try
307     {
308       auto node_id = luci::get_node_id(cnode);
309       for (auto selected_id : by_id)
310       {
311         if (selected_id == node_id)
312         {
313           selected_nodes.emplace_back(cnode);
314         }
315       }
316     }
317     catch (const std::runtime_error &)
318     {
319       continue;
320     }
321   }
322
323   return selected_nodes;
324 }
325
326 template <>
327 std::vector<const luci::CircleNode *>
328 OpSelector::select_by<SelectType::NAME>(const std::vector<std::string> &tokens)
329 {
330   loco::Graph *graph = _module->graph(0);
331   std::vector<const luci::CircleNode *> selected_nodes;
332
333   for (auto node : loco::all_nodes(graph))
334   {
335     auto cnode = loco::must_cast<const luci::CircleNode *>(node);
336     std::string node_name = cnode->name();
337
338     for (auto selected_name : tokens)
339       if (selected_name.compare(node_name) == 0) // find the selected name
340         selected_nodes.emplace_back(cnode);
341   }
342
343   return selected_nodes;
344 }
345
346 template <SelectType SELECT_TYPE>
347 std::unique_ptr<luci::Module> OpSelector::select_by(const std::string &str)
348 {
349   auto colon_tokens = ::split_into_vector(str, ',');
350   if (colon_tokens.empty())
351   {
352     throw std::runtime_error{"ERROR: Nothing was entered."};
353   }
354
355   assert(_module->size() == 1);
356
357   auto selected_nodes = select_by<SELECT_TYPE>(colon_tokens);
358
359   // multiout node should be considered
360   IsMultiOutputNode multiout_visitor;
361   std::vector<const luci::CircleNode *> output_nodes;
362   for (const auto &node : selected_nodes)
363   {
364     if (node->accept(&multiout_visitor))
365     {
366       auto outputs = loco::succs(node);
367       for (auto &o : outputs)
368       {
369         output_nodes.push_back(dynamic_cast<luci::CircleNode *>(o));
370       }
371     }
372   }
373   selected_nodes.insert(selected_nodes.end(), output_nodes.begin(), output_nodes.end());
374
375   auto new_module = std::make_unique<luci::Module>();
376   new_module->add(::make_graph(selected_nodes));
377
378   return new_module;
379 }
380
381 template std::unique_ptr<luci::Module>
382 OpSelector::select_by<SelectType::ID>(const std::string &str);
383
384 template std::unique_ptr<luci::Module>
385 OpSelector::select_by<SelectType::NAME>(const std::string &str);
386
387 } // namespace opselector