2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "OpSelector.h"
19 #include <luci/ConnectNode.h>
20 #include <luci/Profile/CircleNodeID.h>
21 #include <luci/Service/CircleNodeClone.h>
/**
 * @brief Split @p str on @p delim and return the non-empty pieces, in order
 *
 * Empty pieces (from leading, trailing or repeated delimiters) are dropped.
 * Typical inputs:
 *
 * - '1,2,5'        (ids, split on ',')
 * - '1-5,7'        (each comma token may later be split again on '-')
 * - 'tensor_a,tensor_b,tensor_d'
 *
 * NOTE. 1-5 is same with '1,2,3,4,5'.
 *
 * WARNING. SelectType::NAME doesn't allow '-' like 'tensor_a-tensor_c'.
 */
std::vector<std::string> split_into_vector(const std::string &str, const char &delim)
{
  std::vector<std::string> tokens;
  std::istringstream stream{str};
  std::string token;

  while (std::getline(stream, token, delim))
  {
    // Skip the empty pieces produced by adjacent or dangling delimiters
    if (!token.empty())
    {
      tokens.emplace_back(token);
    }
  }

  return tokens;
}
/**
 * @brief Return true if @p s is non-empty and made only of decimal digits
 *
 * NOTE No sign is accepted, so "-1" and "+1" are rejected.
 */
bool is_number(const std::string &s)
{
  if (s.empty())
    return false;

  return std::all_of(s.begin(), s.end(),
                     [](unsigned char ch) { return std::isdigit(ch) != 0; });
}
67 bool is_number(const std::vector<std::string> &vec)
69 for (const auto &s : vec)
71 if (not::is_number(s))
79 // TODO Move this class into a separate header for reuse
80 class IsMultiOutputNode final : public luci::CircleNodeVisitor<bool>
83 bool visit(const luci::CircleCustom *) final { return true; }
84 bool visit(const luci::CircleIf *) final { return true; }
85 bool visit(const luci::CircleNonMaxSuppressionV4 *) final { return true; }
86 bool visit(const luci::CircleNonMaxSuppressionV5 *) final { return true; }
87 bool visit(const luci::CircleSplit *) final { return true; }
88 bool visit(const luci::CircleSplitV *) final { return true; }
89 bool visit(const luci::CircleTopKV2 *) final { return true; }
90 bool visit(const luci::CircleUnique *) final { return true; }
91 bool visit(const luci::CircleUnpack *) final { return true; }
92 bool visit(const luci::CircleWhile *) final { return true; }
94 bool visit(const luci::CircleNode *) final { return false; }
97 std::unique_ptr<loco::Graph> make_graph(const std::vector<const luci::CircleNode *> nodes)
99 auto graph = loco::make_graph();
101 luci::CloneContext ctx;
103 for (const auto &n : nodes)
105 auto clone = luci::clone_node(n, graph.get());
106 ctx.emplace(n, clone);
109 for (const auto &n : nodes)
111 for (uint32_t i = 0; i < n->arity(); i++)
113 auto arg = n->arg(i);
114 auto input_node = dynamic_cast<luci::CircleNode *>(arg);
115 auto ctx_it = ctx.find(input_node);
116 // check if the node already has been cloned
117 if (ctx_it != ctx.end())
119 // the node isn't graph input if it is an other node's input
120 if (std::find(nodes.begin(), nodes.end(), arg) != nodes.end())
122 auto circle_const = dynamic_cast<luci::CircleConst *>(arg);
123 if (circle_const != nullptr)
125 auto clone = luci::clone_node(circle_const, graph.get());
126 ctx.emplace(circle_const, clone);
131 auto circle_input = graph->nodes()->create<luci::CircleInput>();
132 input_node = dynamic_cast<luci::CircleNode *>(arg);
135 throw std::runtime_error{"ERROR: Invalid graph"};
137 luci::copy_common_attributes(input_node, circle_input);
138 ctx.emplace(input_node, circle_input);
140 auto graph_input = graph->inputs()->create();
141 graph_input->name(circle_input->name());
142 graph_input->dtype(circle_input->dtype());
144 auto input_shape = std::make_unique<loco::TensorShape>();
145 input_shape->rank(circle_input->rank());
146 for (uint32_t i = 0; i < circle_input->rank(); i++)
148 if (circle_input->dim(i).known())
150 circle_input->dim(i).set(circle_input->dim(i).value());
153 graph_input->shape(std::move(input_shape));
155 circle_input->index(graph_input->index());
160 for (auto &n : nodes)
162 auto outputs = loco::succs(n);
163 bool beingUsed = false;
164 for (const auto &o : outputs)
166 if (std::find(nodes.begin(), nodes.end(), o) != nodes.end())
172 // the node isn't graph output if it is an other node's output
176 IsMultiOutputNode multiout_visitor;
177 bool isMultiOut = n->accept(&multiout_visitor);
178 for (auto &o : outputs)
180 const luci::CircleNode *output_node = nullptr;
183 output_node = dynamic_cast<const luci::CircleNode *>(o);
186 throw std::runtime_error{"ERROR: Invalid graph"};
194 auto circle_output = graph->nodes()->create<luci::CircleOutput>();
195 luci::copy_common_attributes(output_node, circle_output);
196 // connect to cloned output node
197 circle_output->from(ctx.find(output_node)->second);
199 auto graph_output = graph->outputs()->create();
200 graph_output->name(output_node->name());
201 graph_output->dtype(output_node->dtype());
202 // graph output shape
203 auto output_shape = std::make_unique<loco::TensorShape>();
204 output_shape->rank(circle_output->rank());
205 for (uint32_t i = 0; i < output_shape->rank(); i++)
207 if (circle_output->dim(i).known())
209 output_shape->dim(i).set(circle_output->dim(i).value());
212 graph_output->shape(std::move(output_shape));
214 circle_output->index(graph_output->index());
220 for (const auto &n : nodes)
222 luci::clone_connect(n, ctx);
233 OpSelector::OpSelector(const luci::Module *module) : _module{module}
235 if (_module->size() != 1)
237 throw std::runtime_error{"ERROR: Not support two or more subgraphs"};
242 std::vector<const luci::CircleNode *>
243 OpSelector::select_by<SelectType::ID>(const std::vector<std::string> &comma_tokens)
245 std::vector<uint32_t> by_id;
247 for (const auto &comma_token : comma_tokens)
249 auto dash_tokens = ::split_into_vector(comma_token, '-');
250 if (not::is_number(dash_tokens))
252 throw std::runtime_error{
253 "ERROR: To select operator by id, please use these args: [0-9], '-', ','"};
256 // Convert string into integer
257 std::vector<uint32_t> int_tokens;
260 std::transform(dash_tokens.begin(), dash_tokens.end(), std::back_inserter(int_tokens),
261 [](const std::string &str) { return static_cast<uint32_t>(std::stoi(str)); });
263 catch (const std::out_of_range &)
265 // Uf input is big integer like '123467891234', stoi throws this exception.
266 throw std::runtime_error{"ERROR: Argument is out of range."};
270 throw std::runtime_error{"ERROR: Unknown error"};
273 switch (int_tokens.size())
275 case 0: // inputs like "-"
277 throw std::runtime_error{"ERROR: Nothing was entered"};
279 case 1: // inputs like "1", "2"
281 by_id.push_back(int_tokens.at(0));
284 case 2: // inputs like "1-2", "11-50"
286 for (uint32_t i = int_tokens.at(0); i <= int_tokens.at(1); i++)
292 default: // inputs like "1-2-3"
294 throw std::runtime_error{"ERROR: Too many '-' in str."};
299 loco::Graph *graph = _module->graph(0);
300 std::vector<const luci::CircleNode *> selected_nodes;
302 for (auto node : loco::all_nodes(graph))
304 auto cnode = loco::must_cast<const luci::CircleNode *>(node);
308 auto node_id = luci::get_node_id(cnode);
309 for (auto selected_id : by_id)
311 if (selected_id == node_id)
313 selected_nodes.emplace_back(cnode);
317 catch (const std::runtime_error &)
323 return selected_nodes;
327 std::vector<const luci::CircleNode *>
328 OpSelector::select_by<SelectType::NAME>(const std::vector<std::string> &tokens)
330 loco::Graph *graph = _module->graph(0);
331 std::vector<const luci::CircleNode *> selected_nodes;
333 for (auto node : loco::all_nodes(graph))
335 auto cnode = loco::must_cast<const luci::CircleNode *>(node);
336 std::string node_name = cnode->name();
338 for (auto selected_name : tokens)
339 if (selected_name.compare(node_name) == 0) // find the selected name
340 selected_nodes.emplace_back(cnode);
343 return selected_nodes;
346 template <SelectType SELECT_TYPE>
347 std::unique_ptr<luci::Module> OpSelector::select_by(const std::string &str)
349 auto colon_tokens = ::split_into_vector(str, ',');
350 if (colon_tokens.empty())
352 throw std::runtime_error{"ERROR: Nothing was entered."};
355 assert(_module->size() == 1);
357 auto selected_nodes = select_by<SELECT_TYPE>(colon_tokens);
359 // multiout node should be considered
360 IsMultiOutputNode multiout_visitor;
361 std::vector<const luci::CircleNode *> output_nodes;
362 for (const auto &node : selected_nodes)
364 if (node->accept(&multiout_visitor))
366 auto outputs = loco::succs(node);
367 for (auto &o : outputs)
369 output_nodes.push_back(dynamic_cast<luci::CircleNode *>(o));
373 selected_nodes.insert(selected_nodes.end(), output_nodes.begin(), output_nodes.end());
375 auto new_module = std::make_unique<luci::Module>();
376 new_module->add(::make_graph(selected_nodes));
381 template std::unique_ptr<luci::Module>
382 OpSelector::select_by<SelectType::ID>(const std::string &str);
384 template std::unique_ptr<luci::Module>
385 OpSelector::select_by<SelectType::NAME>(const std::string &str);
387 } // namespace opselector