b6fe89c8aa99162939542041c3dff643def4642a
[platform/upstream/dldt.git] / ngraph / core / src / graph_util.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <numeric>
18 #include <unordered_map>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "ngraph/descriptor/input.hpp"
23 #include "ngraph/descriptor/output.hpp"
24 #include "ngraph/function.hpp"
25 #include "ngraph/graph_util.hpp"
26 #include "ngraph/log.hpp"
27 #include "ngraph/node.hpp"
28 #include "ngraph/op/broadcast.hpp"
29 #include "ngraph/op/constant.hpp"
30 #include "ngraph/op/parameter.hpp"
31 #include "ngraph/op/result.hpp"
32 #include "ngraph/op/tensor_iterator.hpp"
33 #include "ngraph/op/util/op_types.hpp"
34 #include "ngraph/opsets/opset5.hpp"
35 #include "ngraph/pass/manager.hpp"
36 #include "ngraph/pass/visualize_tree.hpp"
37 #include "ngraph/provenance.hpp"
38 #include "ngraph/rt_info.hpp"
39 #include "ngraph/util.hpp"
40
41 NGRAPH_SUPPRESS_DEPRECATED_START
42
43 using namespace std;
44 using namespace ngraph;
45
46 void ngraph::traverse_nodes(const std::shared_ptr<const Function> p,
47                             std::function<void(std::shared_ptr<Node>)> f)
48 {
49     traverse_nodes(p.get(), f);
50 }
51
52 void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f)
53 {
54     NodeVector nodes;
55
56     for (auto r : p->get_results())
57     {
58         nodes.push_back(r);
59     }
60
61     for (auto param : p->get_parameters())
62     {
63         nodes.push_back(param);
64     }
65
66     traverse_nodes(nodes, f);
67 }
68
69 void ngraph::traverse_nodes(const NodeVector& subgraph_results,
70                             std::function<void(std::shared_ptr<Node>)> f,
71                             const NodeVector& subgraph_params)
72 {
73     std::unordered_set<Node*> instances_seen;
74     std::stack<Node*, std::vector<Node*>> stack;
75     for (auto& node_ptr : subgraph_params)
76     {
77         instances_seen.insert(node_ptr.get());
78     }
79     for (auto& node_ptr : subgraph_results)
80     {
81         stack.push(node_ptr.get());
82     }
83
84     while (!stack.empty())
85     {
86         Node* n = stack.top();
87         stack.pop();
88         if (instances_seen.insert(n).second)
89         {
90             f(n->shared_from_this());
91             for (size_t i = 0; i < n->inputs().size(); i++)
92             {
93                 stack.push(n->get_input_node_ptr(i));
94             }
95
96             for (auto& cdep : n->get_control_dependencies())
97             {
98                 stack.push(cdep.get());
99             }
100         }
101     }
102 }
103
104 NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr<Node> node2)
105 {
106     std::unordered_set<std::shared_ptr<Node>> node1_args;
107
108     auto compute_node1_args = [&node1_args](const std::shared_ptr<Node> node) {
109         node1_args.insert(node);
110     };
111
112     traverse_nodes({node1}, compute_node1_args, NodeVector{});
113
114     std::unordered_set<std::shared_ptr<Node>> node2_args;
115
116     auto compute_node2_args = [&node2_args](const std::shared_ptr<Node> node) {
117         node2_args.insert(node);
118     };
119
120     traverse_nodes({node2}, compute_node2_args, NodeVector{});
121
122     NodeVector common_args;
123     for (const auto& e : node1_args)
124     {
125         if (node2_args.count(e) > 0)
126         {
127             common_args.push_back(e);
128         }
129     }
130
131     return common_args;
132 }
133
134 void ngraph::replace_node(std::shared_ptr<Node> target,
135                           std::shared_ptr<Node> replacement,
136                           const std::vector<int64_t>& output_order)
137 {
138     if (ngraph::op::is_output(target))
139     {
140         throw ngraph_error("Result nodes cannot be replaced.");
141     }
142
143     NGRAPH_CHECK(target->get_output_size() == output_order.size(),
144                  "Target output size: ",
145                  target->get_output_size(),
146                  " must be equal output_order size: ",
147                  output_order.size());
148
149     // Fix input/output descriptors
150     NGRAPH_CHECK(target->get_output_size() == replacement->get_output_size());
151
152     if (ngraph::get_provenance_enabled())
153     {
154         auto common_args = ngraph::find_common_args(target, replacement);
155
156         std::set<string> removed_subgraph_tags;
157
158         auto set_replacement_prov = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
159             for (auto tag : node->get_provenance_tags())
160             {
161                 removed_subgraph_tags.insert(tag);
162             }
163         };
164
165         traverse_nodes({target}, set_replacement_prov, common_args);
166         replacement->add_provenance_tags(removed_subgraph_tags);
167
168         auto set_prov_new_nodes = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
169             node->add_provenance_tags(removed_subgraph_tags);
170         };
171
172         traverse_nodes({replacement}, set_prov_new_nodes, common_args);
173     }
174
175     // For each of target's output O with replacement output O_rep:
176     //     For each O's connected downstream input I:
177     //         Change I's connected upstream output to O_rep
178     for (size_t i = 0; i < target->get_output_size(); i++)
179     {
180         for (auto& input : target->output(i).get_target_inputs())
181         {
182             input.replace_source_output(replacement->output(output_order[i]));
183         }
184     }
185
186     replacement->add_node_control_dependents(target);
187     target->clear_control_dependents();
188 }
189
190 void ngraph::replace_node(const std::shared_ptr<Node>& target,
191                           const OutputVector& replacement_values)
192 {
193     if (ngraph::op::is_output(target))
194     {
195         throw ngraph_error("Result nodes cannot be replaced.");
196     }
197
198     NGRAPH_CHECK(target->get_output_size() == replacement_values.size());
199
200     unordered_set<shared_ptr<Node>> replacement_nodes;
201     // For each of target's output O with replacement output O_rep:
202     //     For each O's connected downstream input I:
203     //         Change I's connected upstream output to O_rep
204     for (size_t i = 0; i < target->get_output_size(); i++)
205     {
206         auto& replacement_value = replacement_values.at(i);
207         auto replacement_node = replacement_value.get_node_shared_ptr();
208         if (replacement_nodes.find(replacement_node) == replacement_nodes.end())
209         {
210             replacement_node->add_node_control_dependents(target);
211             target->transfer_provenance_tags(replacement_node);
212             replacement_nodes.insert(replacement_node);
213         }
214         target->output(i).replace(replacement_values.at(i));
215     }
216     target->clear_control_dependents();
217 }
218
219 void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
220 {
221     auto default_output_order = vector<int64_t>(target->get_output_size());
222     std::iota(default_output_order.begin(), default_output_order.end(), 0);
223     replace_node(target, replacement, default_output_order);
224 }
225
226 void ngraph::replace_nodes(
227     const std::shared_ptr<Function>& f,
228     const unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Parameter>>&
229         parameter_replacement_map,
230     const unordered_map<shared_ptr<Node>, shared_ptr<Node>>& body_replacement_map)
231 {
232     auto& params = f->get_parameters();
233
234     for (size_t i = 0; i < params.size(); i++)
235     {
236         if (parameter_replacement_map.count(params[i]) != 0 &&
237             parameter_replacement_map.at(params[i]) != params[i])
238         {
239             f->replace_parameter(i, parameter_replacement_map.at(params[i]));
240         }
241     }
242
243     for (auto& kv : body_replacement_map)
244     {
245         auto& k = kv.first;
246         auto& v = kv.second;
247
248         if (k != v)
249         {
250             f->replace_node(k, v);
251         }
252     }
253 }
254
255 // Check if all paths from X to a result go through Y
256 bool ngraph::is_post_dominated(Node* X, Node* Y)
257 {
258     std::unordered_set<Node*> visited;
259     std::stack<Node*, std::vector<Node*>> stack;
260     stack.push(X);
261
262     while (stack.size() > 0)
263     {
264         ngraph::Node* curr = stack.top();
265         visited.insert(curr);
266         if (ngraph::op::is_output(curr))
267         {
268             return false;
269         }
270         stack.pop();
271         if (curr != Y)
272         {
273             for (const auto& next : curr->get_users())
274             {
275                 if (visited.count(next.get()) == 0)
276                 {
277                     stack.push(next.get());
278                 }
279             }
280         }
281     }
282     return true;
283 }
284
285 std::vector<std::shared_ptr<ngraph::Node>>
286     ngraph::clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
287 {
288     // for each node in topological order
289     auto sorted_nodes = topological_sort(nodes);
290     for (auto node : sorted_nodes)
291     {
292         if (node_map.count(node.get()) == 0)
293         {
294             // get (already) cloned arguments and clone the node
295             OutputVector cloned_args;
296             for (auto input : node->inputs())
297             {
298                 Output<Node> output = input.get_source_output();
299                 cloned_args.push_back(output.for_node(node_map.at(output.get_node())));
300             }
301             std::vector<std::shared_ptr<Node>> cloned_dependencies;
302             for (auto& dependency : node->get_control_dependencies())
303             {
304                 shared_ptr<Node>& dependent = node_map.at(dependency.get());
305                 if (find(cloned_dependencies.begin(), cloned_dependencies.end(), dependent) ==
306                     cloned_dependencies.end())
307                 {
308                     cloned_dependencies.push_back(dependent);
309                 }
310             }
311             auto cloned_node = node->copy_with_new_inputs(cloned_args, cloned_dependencies);
312             // There is a friendly name for this node so copy it
313             cloned_node->set_friendly_name(node->get_friendly_name());
314             //  TODO: workaround for shape inference, delete it after fix
315             if (std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(cloned_node))
316             {
317                 cloned_node->validate_and_infer_types();
318             }
319             auto rt_info = node->get_rt_info();
320             cloned_node->get_rt_info() = rt_info;
321
322             for (auto tag : node->get_provenance_tags())
323             {
324                 cloned_node->add_provenance_tag(tag);
325             }
326             cloned_node->set_op_annotations(node->get_op_annotations());
327
328             node_map[node.get()] = cloned_node;
329         }
330     }
331
332     // create and return vector of cloned nodes
333     // order matches input vector (not necessarily topological)
334     std::vector<std::shared_ptr<ngraph::Node>> cloned_nodes;
335     for (auto node : nodes)
336     {
337         cloned_nodes.push_back(node_map.at(node.get()));
338     }
339     return cloned_nodes;
340 }
341
342 std::list<std::shared_ptr<ngraph::Node>>
343     ngraph::clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
344                         RawNodeOutputMap& output_map)
345 {
346     // for each node in topological order
347     auto sorted_nodes = topological_sort(nodes);
348     std::list<shared_ptr<Node>> cloned_nodes;
349     for (auto node : sorted_nodes)
350     {
351         auto node_outputs = node->outputs();
352         for (auto value : node_outputs)
353         {
354             if (output_map.count(value) == 0)
355             {
356                 // We need this node cloned
357                 // get (already) cloned arguments and clone the node
358                 OutputVector cloned_args;
359                 for (auto value : node->input_values())
360                 {
361                     cloned_args.push_back(output_map.at(value));
362                 }
363                 NodeVector cloned_dependencies;
364                 for (auto& dependency : node->get_control_dependencies())
365                 {
366                     for (auto dependency_value : dependency->outputs())
367                     {
368                         shared_ptr<Node> dependent =
369                             output_map.at(dependency_value).get_node_shared_ptr();
370                         if (find(cloned_dependencies.begin(),
371                                  cloned_dependencies.end(),
372                                  dependent) == cloned_dependencies.end())
373                         {
374                             cloned_dependencies.push_back(dependent);
375                         }
376                     }
377                 }
378                 auto cloned_node = node->copy_with_new_inputs(cloned_args, cloned_dependencies);
379                 cloned_nodes.push_back(cloned_node);
380                 // There is a friendly name for this node so copy it
381                 cloned_node->set_friendly_name(node->get_friendly_name());
382                 //  TODO: workaround for shape inference, delete it after fix
383                 if (std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(cloned_node))
384                 {
385                     cloned_node->validate_and_infer_types();
386                 }
387                 auto rt_info = node->get_rt_info();
388                 cloned_node->get_rt_info() = rt_info;
389
390                 for (auto tag : node->get_provenance_tags())
391                 {
392                     cloned_node->add_provenance_tag(tag);
393                 }
394                 cloned_node->set_op_annotations(node->get_op_annotations());
395                 for (auto cloned_value : cloned_node->outputs())
396                 {
397                     auto original_value = node_outputs.at(cloned_value.get_index());
398                     if (output_map.count(original_value) == 0)
399                     {
400                         output_map[original_value] = cloned_value;
401                     }
402                 }
403                 break;
404             }
405         }
406     }
407     return cloned_nodes;
408 }
409
410 std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function& func)
411 {
412     NodeMap nm;
413     return clone_function(func, nm);
414 }
415
416 std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function& func,
417                                                          NodeMap& node_map)
418 {
419     // clone function operations
420     clone_nodes(func.get_ops(), node_map);
421
422     // get cloned function results and parameters
423     ResultVector cloned_results;
424     for (shared_ptr<Node> node : func.get_results())
425     {
426         auto result = as_type_ptr<op::Result>(node_map.at(node.get()));
427         if (!result)
428         {
429             throw ngraph_error("Results should be of type op::Result");
430         }
431         cloned_results.push_back(result);
432     }
433     std::vector<std::shared_ptr<op::Parameter>> cloned_params;
434     for (auto param : func.get_parameters())
435     {
436         cloned_params.push_back(as_type_ptr<op::Parameter>(node_map.at(param.get())));
437     }
438
439     // create and return cloned function
440     auto result = std::make_shared<ngraph::Function>(cloned_results, cloned_params);
441     result->set_friendly_name(func.get_friendly_name());
442     return result;
443 }
444
445 bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>& reduce_constant)
446 {
447     if (auto rc = as_type_ptr<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
448     {
449         return (rc->get_all_data_elements_bitwise_identical() &&
450                 rc->convert_value_to_string(0) == const_value);
451     }
452     else
453     {
454         return false;
455     }
456 }
457
458 // Insert result and parameter node between src_node and dst_node by splitting the graph
459 //
460 // Before:                        |  After:
461 // (Device:0)         (Device:1)  |  (Device:0)         (Device:0)  (Device:1)         (Device:1)
462 // +-----+---+       +---+-----+  |  +-----+---+       +---+-----+  +-----+---+       +---+-----+
463 // |     |   |       |   |     |  |  |     |   |       |   |     |  |     |   |       |   |     |
464 // |     | o +--[0]--> i |     |  |  |     | o +--[4]--> i |     |  |     | o +--[8]--> i |     |
465 // |     |   <--[1]--+   |     |  |  |     |   <--[5]--+   |     |  |     |   <--[9]--+   |     |
466 // | src +---+       +---+ dst |  |  | src +---+       +---+ res |  | par +---+       +---+ dst |
467 // |     |               |     |  |  |     |               |     |  |     |               |     |
468 // |     +------[2]------>     |  |  |     +------[6]------>     |  |     +------[10]----->     |
469 // |     <------[3]------+     |  |  |     <------[7]------+     |  |     <------[11]-----+     |
470 // +-----+               +-----+  |  +-----+               +-----+  +-----+               +-----+
471 pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
472     ngraph::insert_result_parameter_split(const shared_ptr<Node>& src_node,
473                                           const shared_ptr<Node>& dst_node)
474 {
475     if (src_node->get_output_size() != 1)
476     {
477         throw ngraph_error("Multiple output per op not supported in graph partition yet.");
478     }
479
480     // Make parameter node
481     shared_ptr<op::Parameter> par_node = make_shared<op::Parameter>(
482         src_node->get_output_element_type(0), src_node->get_output_shape(0));
483
484     // Fix input / output among src, dst and par
485     std::vector<Input<Node>> dst_inputs = get_inputs_from(*src_node, *dst_node);
486     NGRAPH_CHECK(dst_inputs.size() == 1,
487                  "insert_result_parameter_split encountered more than "
488                  "one input between the source and destination nodes");
489     auto& dst_input = dst_inputs[0];
490
491     std::vector<Output<Node>> src_outputs = get_outputs_to(*src_node, *dst_node);
492     NGRAPH_CHECK(src_outputs.size() == 1,
493                  "insert_result_parameter_split encountered more than "
494                  "one output between the source and destination nodes");
495     auto& src_output = src_outputs[0];
496
497     // Remove [0]
498     src_output.remove_target_input(dst_input);
499
500     // Remove [0] (again), add [8], remove [1], add [9]
501     dst_input.replace_source_output(par_node->output(0));
502
503     // Add res node
504     // Add [4], [5], [6], [7]
505     shared_ptr<op::Result> res_node = make_shared<op::Result>(src_node);
506
507     return make_pair(res_node, par_node);
508 }
509
510 // Insert unary node between two nodes like S->D => S->N->D
511 // Before:                        |  After:
512 // +-----+---+       +---+-----+  |  +-----+---+       +---+-----+---+       +---+-----+
513 // |     |   |       |   |     |  |  |     |   |       |   |     |   |       |   |     |
514 // |     | o +--[0]--> i |     |  |  |     | o +--[4]--> i |     | o +--[8]--> i |     |
515 // |     |   <--[1]--+   |     |  |  |     |   <--[5]--+   |     |   <--[9]--+   |     |
516 // | src +---+       +---+ dst |  |  | src +---+       +---+ new +---+       +---+ dst |
517 // |     |               |     |  |  |     |               |     |               |     |
518 // |     +------[2]------>     |  |  |     +------[6]------>     +------[10]----->     |
519 // |     <------[3]------+     |  |  |     <------[7]------+     <------[11]-----+     |
520 // +-----+               +-----+  |  +-----+               +-----+               +-----+
521 //                                |
522 // +-----+---+       +---+-----+  |
523 // |     |   |       |   |     |  |
524 // |     | o +--[4]--> i |     |  |
525 // |     |   <--[5]--+   |     |  |
526 // | src +---+       +---+ new |  |
527 // |     |               |     |  |
528 // |     +------[6]------>     |  |
529 // |     <------[7]------+     |  |
530 // +-----+               +-----+  |
531 //
532 // This cannot be achieved by ngraph::replace_node().
533 // With replace_node(), we could do:
534 // [     S           S      ]
535 // [    / \          |      ]
536 // [   /   \   =>    N      ]
537 // [  /     \       / \     ]
538 // [ D0     D1    D0   D1   ]
539 //
540 // But we want:
541 // [     S            S     ]
542 // [    / \          / \    ]
543 // [   /   \   =>   N0  N1  ]
544 // [  /     \      /     \  ]
545 // [ D0     D1    D0     D1 ]
546 //
547 // Typically new_node is connected to src_node already. The reason we don't create `new_node`
548 // inside the function and return it (similar to ngraph::insert_result_parameter_split) is that
549 // we'll have to templatize its function to call new_node's constructor.
550 void ngraph::insert_new_node_between(const shared_ptr<Node>& src_node,
551                                      const shared_ptr<Node>& dst_node,
552                                      const shared_ptr<Node>& new_node)
553 {
554     // Fix input / output
555     std::vector<Input<Node>> dst_inputs = get_inputs_from(*src_node, *dst_node);
556     NGRAPH_CHECK(dst_inputs.size() == 1,
557                  "insert_new_node_between encountered more than one "
558                  "input between the source and destination nodes");
559     auto& dst_input = dst_inputs[0];
560
561     std::vector<Output<Node>> src_outputs = get_outputs_to(*src_node, *dst_node);
562     NGRAPH_CHECK(src_outputs.size() == 1,
563                  "insert_new_node_between encountered more than one "
564                  "output between the source and destination nodes");
565     auto& src_output = src_outputs[0];
566
567     src_output.remove_target_input(dst_input); // Remove [0]
568     dst_input.replace_source_output(
569         new_node->output(0)); // Remove [0] (again), add [8], remove [1], add [9]
570 }
571
572 std::shared_ptr<Node> ngraph::make_zero(const element::Type& element_type, const Shape& shape)
573 {
574     std::shared_ptr<Node> zero = op::Constant::create(element_type, Shape{}, {0.0});
575     if (shape.size() > 0)
576     {
577         AxisSet axes;
578         for (size_t i = 0; i < shape.size(); i++)
579         {
580             axes.insert(i);
581         }
582         zero = std::make_shared<op::Broadcast>(zero, shape, axes);
583     }
584     return zero;
585 }
586
587 std::shared_ptr<Node> ngraph::make_constant_from_string(std::string val,
588                                                         const element::Type& element_type,
589                                                         const Shape& shape)
590 {
591     auto cvals = std::vector<std::string>(shape_size(shape), val);
592     return std::make_shared<op::Constant>(element_type, shape, cvals);
593 }
594
595 bool ngraph::is_zero(const Output<Node>& reduce_constant)
596 {
597     auto result_bool = is_equal_to_const_value("0", reduce_constant);
598     return result_bool;
599 }
600
601 bool ngraph::is_one(const Output<Node>& reduce_constant)
602 {
603     auto result_bool = is_equal_to_const_value("1", reduce_constant);
604     return result_bool;
605 }
606
607 NodeVector ngraph::get_subgraph_outputs(const NodeVector& nodes,
608                                         const NodeVector& exclusions,
609                                         bool ignore_unused,
610                                         bool ignore_output_duplicates)
611 {
612     std::set<shared_ptr<Node>> exclusions_set(exclusions.begin(), exclusions.end());
613     std::set<shared_ptr<Node>> nodes_set(nodes.begin(), nodes.end());
614
615     NodeVector outputs;
616
617     for (auto n : nodes)
618     {
619         if (exclusions_set.count(n) != 0)
620         {
621             continue;
622         }
623
624         for (const auto& u : n->get_users())
625         {
626             bool add_output = nodes_set.count(u) == 0 && (!ignore_unused || is_used(u.get()));
627             // check if output is already captured
628             add_output &= (ignore_output_duplicates ||
629                            std::find(outputs.begin(), outputs.end(), n) == outputs.end());
630             if (add_output)
631             {
632                 outputs.push_back(n);
633             }
634         }
635     }
636     return outputs;
637 }
638
639 NodeVector ngraph::extract_subgraph(const NodeVector& results, const NodeVector& args)
640 {
641     NodeVector subgraph;
642     traverse_nodes(results, [&](std::shared_ptr<Node> n) { subgraph.push_back(n); }, args);
643     return subgraph;
644 }
645
646 bool ngraph::is_used(Node* node)
647 {
648     std::unordered_set<Node*> instances_seen;
649     std::stack<Node*, std::vector<Node*>> stack;
650     stack.push(node);
651
652     while (stack.size() > 0)
653     {
654         ngraph::Node* n = stack.top();
655         if (instances_seen.count(n) == 0)
656         {
657             if (ngraph::op::is_output(n))
658             {
659                 return true;
660             }
661             instances_seen.insert(n);
662         }
663         stack.pop();
664         for (const auto& arg : n->get_users())
665         {
666             if (instances_seen.count(arg.get()) == 0)
667             {
668                 stack.push(arg.get());
669             }
670         }
671     }
672     return false;
673 }
674
675 size_t ngraph::get_user_count(Node* node)
676 {
677     size_t count = 0;
678     for (const auto& node_user : node->get_users())
679     {
680         count += is_used(node_user.get());
681     }
682     return count;
683 }
684
685 bool ngraph::possibly_overwritten(Node* node)
686 {
687     for (auto& output : node->outputs())
688     {
689         for (auto& input : output.get_target_inputs())
690         {
691             if (op::is_op(input.get_node()))
692             {
693                 auto op = static_cast<ngraph::op::Op*>(input.get_node());
694                 if (auto op_annotations = op->get_op_annotations())
695                 {
696                     for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
697                     {
698                         if (input.get_index() == oi_pair.input && oi_pair.destructive)
699                         {
700                             return true;
701                         }
702                     }
703                 }
704             }
705         }
706     }
707     return false;
708 }
709
710 bool ngraph::is_strided(const Strides& strides)
711 {
712     return std::any_of(strides.begin(), strides.end(), [](size_t stride) { return stride != 1; });
713 }
714
715 bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks)
716 {
717     auto node_rank = node->get_shape().size();
718     for (auto rank : valid_ranks)
719     {
720         if (rank == node_rank)
721         {
722             return true;
723         }
724     }
725     return false;
726 }
727
728 bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2)
729 {
730     if (!(op::is_constant(n1) && op::is_constant(n2)))
731     {
732         return false;
733     }
734
735     if (static_pointer_cast<op::Constant>(n1)->get_value_strings() !=
736         static_pointer_cast<op::Constant>(n2)->get_value_strings())
737     {
738         return false;
739     }
740
741     return true;
742 }
743
744 void ngraph::plot_graph(
745     std::shared_ptr<Function> f,
746     const std::string& filename,
747     std::function<void(const Node& node, std::vector<std::string>& attributes)> attributes)
748 {
749     ngraph::pass::Manager pass_manager;
750     pass_manager.register_pass<ngraph::pass::VisualizeTree>(filename, attributes);
751     pass_manager.run_passes(f);
752 }
753
754 std::vector<Input<Node>> ngraph::get_inputs_from(Node& src, Node& dst)
755 {
756     std::vector<Input<Node>> result;
757
758     for (auto& input : dst.inputs())
759     {
760         if (input.get_source_output().get_node() == &src)
761         {
762             result.push_back(input);
763         }
764     }
765
766     return result;
767 }
768
769 std::vector<Output<Node>> ngraph::get_outputs_to(Node& src, Node& dst)
770 {
771     std::vector<Output<Node>> result;
772
773     for (auto& output : src.outputs())
774     {
775         bool targets_dst = false;
776
777         for (auto& input : output.get_target_inputs())
778         {
779             if (input.get_node() == &dst)
780             {
781                 targets_dst = true;
782                 break;
783             }
784         }
785
786         if (targets_dst)
787         {
788             result.push_back(output);
789         }
790     }
791
792     return result;
793 }
794
795 static bool check_for_cycles_bkwd(std::shared_ptr<ngraph::Node> node,
796                                   std::deque<std::shared_ptr<ngraph::Node>>& path,
797                                   std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
798                                   ngraph::NodeVector& cycle_nodes)
799 {
800     path.push_back(node);
801     path_set.insert(node);
802     for (size_t i = 0; i < node->inputs().size(); i++)
803     {
804         auto arg = node->get_input_node_shared_ptr(i);
805         if (path_set.find(arg) != path_set.end())
806         {
807             for (auto it : path)
808             {
809                 cycle_nodes.push_back(it);
810             }
811             // last node
812             cycle_nodes.push_back(arg);
813             return true;
814         }
815         if (check_for_cycles_bkwd(arg, path, path_set, cycle_nodes))
816         {
817             return true;
818         }
819     }
820     path_set.erase(path.back());
821     path.pop_back();
822     return false;
823 }
824
825 static bool check_for_cycles_fwd(std::shared_ptr<ngraph::Node> node,
826                                  std::deque<std::shared_ptr<ngraph::Node>>& path,
827                                  std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
828                                  ngraph::NodeVector& cycle_nodes)
829 {
830     path.push_back(node);
831     path_set.insert(node);
832     for (auto& arg : node->get_users())
833     {
834         if (path_set.find(arg) != path_set.end())
835         {
836             for (auto it : path)
837             {
838                 cycle_nodes.push_back(it);
839             }
840             // last node
841             cycle_nodes.push_back(arg);
842             return true;
843         }
844         if (check_for_cycles_fwd(arg, path, path_set, cycle_nodes))
845         {
846             return true;
847         }
848     }
849     path_set.erase(path.back());
850     path.pop_back();
851     return false;
852 }
853
854 bool ngraph::check_for_cycles(const ngraph::Function* func,
855                               ngraph::NodeVector& cycle_nodes,
856                               bool& is_bkwd_cycle)
857 {
858     for (auto res : func->get_results())
859     {
860         std::deque<std::shared_ptr<Node>> path;
861         // mirror of path stack for faster cycle check
862         std::unordered_set<std::shared_ptr<Node>> path_set;
863         if (check_for_cycles_bkwd(res, path, path_set, cycle_nodes))
864         {
865             is_bkwd_cycle = true;
866             return true;
867         }
868     }
869
870     for (auto param : func->get_parameters())
871     {
872         std::deque<std::shared_ptr<Node>> path;
873         // mirror of path stack for faster cycle check
874         std::unordered_set<std::shared_ptr<Node>> path_set;
875         if (check_for_cycles_fwd(param, path, path_set, cycle_nodes))
876         {
877             is_bkwd_cycle = false;
878             return true;
879         }
880     }
881     // no cycles
882     return false;
883 }
884
885 bool ngraph::replace_output_update_name(Output<Node> output, const Output<Node>& replacement)
886 {
887     bool has_result_output = false;
888     for (auto& target_input : output.get_target_inputs())
889     {
890         if (is_type<op::Result>(target_input.get_node()))
891         {
892             // ignore trivial elimination
893             has_result_output = true;
894             if (is_type<ngraph::op::Parameter>(replacement.get_node()))
895             {
896                 return false;
897             }
898             break;
899         }
900     }
901     if (!has_result_output || replacement.get_node()->get_users().size() == 1)
902     {
903         if (has_result_output && !is_type<ngraph::op::Parameter>(replacement.get_node()))
904         {
905             replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name());
906             // Update output tensor name
907             replacement.get_tensor().set_name(output.get_node()->get_friendly_name());
908         }
909         output.replace(replacement);
910         copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
911                           replacement.get_node_shared_ptr());
912         return true;
913     }
914     return false;
915 }
916
917 bool ngraph::replace_node_update_name(std::shared_ptr<Node> target,
918                                       std::shared_ptr<Node> replacement)
919 {
920     for (auto& output : target->output(0).get_target_inputs())
921     {
922         if (as_type<ngraph::op::Parameter>(replacement->input_value(0).get_node()) &&
923             as_type<op::Result>(output.get_node()))
924         {
925             return false;
926         }
927     }
928     replace_node(target, replacement);
929     replacement->set_friendly_name(target->get_friendly_name());
930     copy_runtime_info(target, replacement);
931     return true;
932 }