//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <deque>
#include <functional>
#include <list>
#include <memory>
#include <numeric>
#include <set>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/rt_info.hpp"
#include "ngraph/util.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace ngraph;
46 void ngraph::traverse_nodes(const std::shared_ptr<const Function> p,
47 std::function<void(std::shared_ptr<Node>)> f)
49 traverse_nodes(p.get(), f);
52 void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_ptr<Node>)> f)
56 for (auto r : p->get_results())
61 for (auto param : p->get_parameters())
63 nodes.push_back(param);
66 traverse_nodes(nodes, f);
69 void ngraph::traverse_nodes(const NodeVector& subgraph_results,
70 std::function<void(std::shared_ptr<Node>)> f,
71 const NodeVector& subgraph_params)
73 std::unordered_set<Node*> instances_seen;
74 std::stack<Node*, std::vector<Node*>> stack;
75 for (auto& node_ptr : subgraph_params)
77 instances_seen.insert(node_ptr.get());
79 for (auto& node_ptr : subgraph_results)
81 stack.push(node_ptr.get());
84 while (!stack.empty())
86 Node* n = stack.top();
88 if (instances_seen.insert(n).second)
90 f(n->shared_from_this());
91 for (size_t i = 0; i < n->inputs().size(); i++)
93 stack.push(n->get_input_node_ptr(i));
96 for (auto& cdep : n->get_control_dependencies())
98 stack.push(cdep.get());
104 NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr<Node> node2)
106 std::unordered_set<std::shared_ptr<Node>> node1_args;
108 auto compute_node1_args = [&node1_args](const std::shared_ptr<Node> node) {
109 node1_args.insert(node);
112 traverse_nodes({node1}, compute_node1_args, NodeVector{});
114 std::unordered_set<std::shared_ptr<Node>> node2_args;
116 auto compute_node2_args = [&node2_args](const std::shared_ptr<Node> node) {
117 node2_args.insert(node);
120 traverse_nodes({node2}, compute_node2_args, NodeVector{});
122 NodeVector common_args;
123 for (const auto& e : node1_args)
125 if (node2_args.count(e) > 0)
127 common_args.push_back(e);
134 void ngraph::replace_node(std::shared_ptr<Node> target,
135 std::shared_ptr<Node> replacement,
136 const std::vector<int64_t>& output_order)
138 if (ngraph::op::is_output(target))
140 throw ngraph_error("Result nodes cannot be replaced.");
143 NGRAPH_CHECK(target->get_output_size() == output_order.size(),
144 "Target output size: ",
145 target->get_output_size(),
146 " must be equal output_order size: ",
147 output_order.size());
149 // Fix input/output descriptors
150 NGRAPH_CHECK(target->get_output_size() == replacement->get_output_size());
152 if (ngraph::get_provenance_enabled())
154 auto common_args = ngraph::find_common_args(target, replacement);
156 std::set<string> removed_subgraph_tags;
158 auto set_replacement_prov = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
159 for (auto tag : node->get_provenance_tags())
161 removed_subgraph_tags.insert(tag);
165 traverse_nodes({target}, set_replacement_prov, common_args);
166 replacement->add_provenance_tags(removed_subgraph_tags);
168 auto set_prov_new_nodes = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
169 node->add_provenance_tags(removed_subgraph_tags);
172 traverse_nodes({replacement}, set_prov_new_nodes, common_args);
175 // For each of target's output O with replacement output O_rep:
176 // For each O's connected downstream input I:
177 // Change I's connected upstream output to O_rep
178 for (size_t i = 0; i < target->get_output_size(); i++)
180 for (auto& input : target->output(i).get_target_inputs())
182 input.replace_source_output(replacement->output(output_order[i]));
186 replacement->add_node_control_dependents(target);
187 target->clear_control_dependents();
190 void ngraph::replace_node(const std::shared_ptr<Node>& target,
191 const OutputVector& replacement_values)
193 if (ngraph::op::is_output(target))
195 throw ngraph_error("Result nodes cannot be replaced.");
198 NGRAPH_CHECK(target->get_output_size() == replacement_values.size());
200 unordered_set<shared_ptr<Node>> replacement_nodes;
201 // For each of target's output O with replacement output O_rep:
202 // For each O's connected downstream input I:
203 // Change I's connected upstream output to O_rep
204 for (size_t i = 0; i < target->get_output_size(); i++)
206 auto& replacement_value = replacement_values.at(i);
207 auto replacement_node = replacement_value.get_node_shared_ptr();
208 if (replacement_nodes.find(replacement_node) == replacement_nodes.end())
210 replacement_node->add_node_control_dependents(target);
211 target->transfer_provenance_tags(replacement_node);
212 replacement_nodes.insert(replacement_node);
214 target->output(i).replace(replacement_values.at(i));
216 target->clear_control_dependents();
219 void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
221 auto default_output_order = vector<int64_t>(target->get_output_size());
222 std::iota(default_output_order.begin(), default_output_order.end(), 0);
223 replace_node(target, replacement, default_output_order);
226 void ngraph::replace_nodes(
227 const std::shared_ptr<Function>& f,
228 const unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Parameter>>&
229 parameter_replacement_map,
230 const unordered_map<shared_ptr<Node>, shared_ptr<Node>>& body_replacement_map)
232 auto& params = f->get_parameters();
234 for (size_t i = 0; i < params.size(); i++)
236 if (parameter_replacement_map.count(params[i]) != 0 &&
237 parameter_replacement_map.at(params[i]) != params[i])
239 f->replace_parameter(i, parameter_replacement_map.at(params[i]));
243 for (auto& kv : body_replacement_map)
250 f->replace_node(k, v);
255 // Check if all paths from X to a result go through Y
256 bool ngraph::is_post_dominated(Node* X, Node* Y)
258 std::unordered_set<Node*> visited;
259 std::stack<Node*, std::vector<Node*>> stack;
262 while (stack.size() > 0)
264 ngraph::Node* curr = stack.top();
265 visited.insert(curr);
266 if (ngraph::op::is_output(curr))
273 for (const auto& next : curr->get_users())
275 if (visited.count(next.get()) == 0)
277 stack.push(next.get());
285 std::vector<std::shared_ptr<ngraph::Node>>
286 ngraph::clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
288 // for each node in topological order
289 auto sorted_nodes = topological_sort(nodes);
290 for (auto node : sorted_nodes)
292 if (node_map.count(node.get()) == 0)
294 // get (already) cloned arguments and clone the node
295 OutputVector cloned_args;
296 for (auto input : node->inputs())
298 Output<Node> output = input.get_source_output();
299 cloned_args.push_back(output.for_node(node_map.at(output.get_node())));
301 std::vector<std::shared_ptr<Node>> cloned_dependencies;
302 for (auto& dependency : node->get_control_dependencies())
304 shared_ptr<Node>& dependent = node_map.at(dependency.get());
305 if (find(cloned_dependencies.begin(), cloned_dependencies.end(), dependent) ==
306 cloned_dependencies.end())
308 cloned_dependencies.push_back(dependent);
311 auto cloned_node = node->copy_with_new_inputs(cloned_args, cloned_dependencies);
312 // There is a friendly name for this node so copy it
313 cloned_node->set_friendly_name(node->get_friendly_name());
314 // TODO: workaround for shape inference, delete it after fix
315 if (std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(cloned_node))
317 cloned_node->validate_and_infer_types();
319 auto rt_info = node->get_rt_info();
320 cloned_node->get_rt_info() = rt_info;
322 for (auto tag : node->get_provenance_tags())
324 cloned_node->add_provenance_tag(tag);
326 cloned_node->set_op_annotations(node->get_op_annotations());
328 node_map[node.get()] = cloned_node;
332 // create and return vector of cloned nodes
333 // order matches input vector (not necessarily topological)
334 std::vector<std::shared_ptr<ngraph::Node>> cloned_nodes;
335 for (auto node : nodes)
337 cloned_nodes.push_back(node_map.at(node.get()));
342 std::list<std::shared_ptr<ngraph::Node>>
343 ngraph::clone_nodes(const std::vector<std::shared_ptr<ngraph::Node>>& nodes,
344 RawNodeOutputMap& output_map)
346 // for each node in topological order
347 auto sorted_nodes = topological_sort(nodes);
348 std::list<shared_ptr<Node>> cloned_nodes;
349 for (auto node : sorted_nodes)
351 auto node_outputs = node->outputs();
352 for (auto value : node_outputs)
354 if (output_map.count(value) == 0)
356 // We need this node cloned
357 // get (already) cloned arguments and clone the node
358 OutputVector cloned_args;
359 for (auto value : node->input_values())
361 cloned_args.push_back(output_map.at(value));
363 NodeVector cloned_dependencies;
364 for (auto& dependency : node->get_control_dependencies())
366 for (auto dependency_value : dependency->outputs())
368 shared_ptr<Node> dependent =
369 output_map.at(dependency_value).get_node_shared_ptr();
370 if (find(cloned_dependencies.begin(),
371 cloned_dependencies.end(),
372 dependent) == cloned_dependencies.end())
374 cloned_dependencies.push_back(dependent);
378 auto cloned_node = node->copy_with_new_inputs(cloned_args, cloned_dependencies);
379 cloned_nodes.push_back(cloned_node);
380 // There is a friendly name for this node so copy it
381 cloned_node->set_friendly_name(node->get_friendly_name());
382 // TODO: workaround for shape inference, delete it after fix
383 if (std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(cloned_node))
385 cloned_node->validate_and_infer_types();
387 auto rt_info = node->get_rt_info();
388 cloned_node->get_rt_info() = rt_info;
390 for (auto tag : node->get_provenance_tags())
392 cloned_node->add_provenance_tag(tag);
394 cloned_node->set_op_annotations(node->get_op_annotations());
395 for (auto cloned_value : cloned_node->outputs())
397 auto original_value = node_outputs.at(cloned_value.get_index());
398 if (output_map.count(original_value) == 0)
400 output_map[original_value] = cloned_value;
410 std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function& func)
413 return clone_function(func, nm);
416 std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function& func,
419 // clone function operations
420 clone_nodes(func.get_ops(), node_map);
422 // get cloned function results and parameters
423 ResultVector cloned_results;
424 for (shared_ptr<Node> node : func.get_results())
426 auto result = as_type_ptr<op::Result>(node_map.at(node.get()));
429 throw ngraph_error("Results should be of type op::Result");
431 cloned_results.push_back(result);
433 std::vector<std::shared_ptr<op::Parameter>> cloned_params;
434 for (auto param : func.get_parameters())
436 cloned_params.push_back(as_type_ptr<op::Parameter>(node_map.at(param.get())));
439 // create and return cloned function
440 auto result = std::make_shared<ngraph::Function>(cloned_results, cloned_params);
441 result->set_friendly_name(func.get_friendly_name());
445 bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>& reduce_constant)
447 if (auto rc = as_type_ptr<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
449 return (rc->get_all_data_elements_bitwise_identical() &&
450 rc->convert_value_to_string(0) == const_value);
458 // Insert result and parameter node between src_node and dst_node by splitting the graph
461 // (Device:0) (Device:1) | (Device:0) (Device:0) (Device:1) (Device:1)
462 // +-----+---+ +---+-----+ | +-----+---+ +---+-----+ +-----+---+ +---+-----+
463 // | | | | | | | | | | | | | | | | | | |
464 // | | o +--[0]--> i | | | | | o +--[4]--> i | | | | o +--[8]--> i | |
465 // | | <--[1]--+ | | | | | <--[5]--+ | | | | <--[9]--+ | |
466 // | src +---+ +---+ dst | | | src +---+ +---+ res | | par +---+ +---+ dst |
467 // | | | | | | | | | | | | |
468 // | +------[2]------> | | | +------[6]------> | | +------[10]-----> |
469 // | <------[3]------+ | | | <------[7]------+ | | <------[11]-----+ |
470 // +-----+ +-----+ | +-----+ +-----+ +-----+ +-----+
471 pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
472 ngraph::insert_result_parameter_split(const shared_ptr<Node>& src_node,
473 const shared_ptr<Node>& dst_node)
475 if (src_node->get_output_size() != 1)
477 throw ngraph_error("Multiple output per op not supported in graph partition yet.");
480 // Make parameter node
481 shared_ptr<op::Parameter> par_node = make_shared<op::Parameter>(
482 src_node->get_output_element_type(0), src_node->get_output_shape(0));
484 // Fix input / output among src, dst and par
485 std::vector<Input<Node>> dst_inputs = get_inputs_from(*src_node, *dst_node);
486 NGRAPH_CHECK(dst_inputs.size() == 1,
487 "insert_result_parameter_split encountered more than "
488 "one input between the source and destination nodes");
489 auto& dst_input = dst_inputs[0];
491 std::vector<Output<Node>> src_outputs = get_outputs_to(*src_node, *dst_node);
492 NGRAPH_CHECK(src_outputs.size() == 1,
493 "insert_result_parameter_split encountered more than "
494 "one output between the source and destination nodes");
495 auto& src_output = src_outputs[0];
498 src_output.remove_target_input(dst_input);
500 // Remove [0] (again), add [8], remove [1], add [9]
501 dst_input.replace_source_output(par_node->output(0));
504 // Add [4], [5], [6], [7]
505 shared_ptr<op::Result> res_node = make_shared<op::Result>(src_node);
507 return make_pair(res_node, par_node);
510 // Insert unary node between two nodes like S->D => S->N->D
512 // +-----+---+ +---+-----+ | +-----+---+ +---+-----+---+ +---+-----+
513 // | | | | | | | | | | | | | | | | |
514 // | | o +--[0]--> i | | | | | o +--[4]--> i | | o +--[8]--> i | |
515 // | | <--[1]--+ | | | | | <--[5]--+ | | <--[9]--+ | |
516 // | src +---+ +---+ dst | | | src +---+ +---+ new +---+ +---+ dst |
517 // | | | | | | | | | | |
518 // | +------[2]------> | | | +------[6]------> +------[10]-----> |
519 // | <------[3]------+ | | | <------[7]------+ <------[11]-----+ |
520 // +-----+ +-----+ | +-----+ +-----+ +-----+
522 // +-----+---+ +---+-----+ |
524 // | | o +--[4]--> i | | |
525 // | | <--[5]--+ | | |
526 // | src +---+ +---+ new | |
528 // | +------[6]------> | |
529 // | <------[7]------+ | |
532 // This cannot be achieved by ngraph::replace_node().
533 // With replace_node(), we could do:
547 // Typically new_node is connected to src_node already. The reason we don't create `new_node`
548 // inside the function and return it (similar to ngraph::insert_result_parameter_split) is that
549 // we'll have to templatize its function to call new_node's constructor.
550 void ngraph::insert_new_node_between(const shared_ptr<Node>& src_node,
551 const shared_ptr<Node>& dst_node,
552 const shared_ptr<Node>& new_node)
554 // Fix input / output
555 std::vector<Input<Node>> dst_inputs = get_inputs_from(*src_node, *dst_node);
556 NGRAPH_CHECK(dst_inputs.size() == 1,
557 "insert_new_node_between encountered more than one "
558 "input between the source and destination nodes");
559 auto& dst_input = dst_inputs[0];
561 std::vector<Output<Node>> src_outputs = get_outputs_to(*src_node, *dst_node);
562 NGRAPH_CHECK(src_outputs.size() == 1,
563 "insert_new_node_between encountered more than one "
564 "output between the source and destination nodes");
565 auto& src_output = src_outputs[0];
567 src_output.remove_target_input(dst_input); // Remove [0]
568 dst_input.replace_source_output(
569 new_node->output(0)); // Remove [0] (again), add [8], remove [1], add [9]
572 std::shared_ptr<Node> ngraph::make_zero(const element::Type& element_type, const Shape& shape)
574 std::shared_ptr<Node> zero = op::Constant::create(element_type, Shape{}, {0.0});
575 if (shape.size() > 0)
578 for (size_t i = 0; i < shape.size(); i++)
582 zero = std::make_shared<op::Broadcast>(zero, shape, axes);
587 std::shared_ptr<Node> ngraph::make_constant_from_string(std::string val,
588 const element::Type& element_type,
591 auto cvals = std::vector<std::string>(shape_size(shape), val);
592 return std::make_shared<op::Constant>(element_type, shape, cvals);
595 bool ngraph::is_zero(const Output<Node>& reduce_constant)
597 auto result_bool = is_equal_to_const_value("0", reduce_constant);
601 bool ngraph::is_one(const Output<Node>& reduce_constant)
603 auto result_bool = is_equal_to_const_value("1", reduce_constant);
607 NodeVector ngraph::get_subgraph_outputs(const NodeVector& nodes,
608 const NodeVector& exclusions,
610 bool ignore_output_duplicates)
612 std::set<shared_ptr<Node>> exclusions_set(exclusions.begin(), exclusions.end());
613 std::set<shared_ptr<Node>> nodes_set(nodes.begin(), nodes.end());
619 if (exclusions_set.count(n) != 0)
624 for (const auto& u : n->get_users())
626 bool add_output = nodes_set.count(u) == 0 && (!ignore_unused || is_used(u.get()));
627 // check if output is already captured
628 add_output &= (ignore_output_duplicates ||
629 std::find(outputs.begin(), outputs.end(), n) == outputs.end());
632 outputs.push_back(n);
639 NodeVector ngraph::extract_subgraph(const NodeVector& results, const NodeVector& args)
642 traverse_nodes(results, [&](std::shared_ptr<Node> n) { subgraph.push_back(n); }, args);
646 bool ngraph::is_used(Node* node)
648 std::unordered_set<Node*> instances_seen;
649 std::stack<Node*, std::vector<Node*>> stack;
652 while (stack.size() > 0)
654 ngraph::Node* n = stack.top();
655 if (instances_seen.count(n) == 0)
657 if (ngraph::op::is_output(n))
661 instances_seen.insert(n);
664 for (const auto& arg : n->get_users())
666 if (instances_seen.count(arg.get()) == 0)
668 stack.push(arg.get());
675 size_t ngraph::get_user_count(Node* node)
678 for (const auto& node_user : node->get_users())
680 count += is_used(node_user.get());
685 bool ngraph::possibly_overwritten(Node* node)
687 for (auto& output : node->outputs())
689 for (auto& input : output.get_target_inputs())
691 if (op::is_op(input.get_node()))
693 auto op = static_cast<ngraph::op::Op*>(input.get_node());
694 if (auto op_annotations = op->get_op_annotations())
696 for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
698 if (input.get_index() == oi_pair.input && oi_pair.destructive)
710 bool ngraph::is_strided(const Strides& strides)
712 return std::any_of(strides.begin(), strides.end(), [](size_t stride) { return stride != 1; });
715 bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks)
717 auto node_rank = node->get_shape().size();
718 for (auto rank : valid_ranks)
720 if (rank == node_rank)
728 bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2)
730 if (!(op::is_constant(n1) && op::is_constant(n2)))
735 if (static_pointer_cast<op::Constant>(n1)->get_value_strings() !=
736 static_pointer_cast<op::Constant>(n2)->get_value_strings())
744 void ngraph::plot_graph(
745 std::shared_ptr<Function> f,
746 const std::string& filename,
747 std::function<void(const Node& node, std::vector<std::string>& attributes)> attributes)
749 ngraph::pass::Manager pass_manager;
750 pass_manager.register_pass<ngraph::pass::VisualizeTree>(filename, attributes);
751 pass_manager.run_passes(f);
754 std::vector<Input<Node>> ngraph::get_inputs_from(Node& src, Node& dst)
756 std::vector<Input<Node>> result;
758 for (auto& input : dst.inputs())
760 if (input.get_source_output().get_node() == &src)
762 result.push_back(input);
769 std::vector<Output<Node>> ngraph::get_outputs_to(Node& src, Node& dst)
771 std::vector<Output<Node>> result;
773 for (auto& output : src.outputs())
775 bool targets_dst = false;
777 for (auto& input : output.get_target_inputs())
779 if (input.get_node() == &dst)
788 result.push_back(output);
795 static bool check_for_cycles_bkwd(std::shared_ptr<ngraph::Node> node,
796 std::deque<std::shared_ptr<ngraph::Node>>& path,
797 std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
798 ngraph::NodeVector& cycle_nodes)
800 path.push_back(node);
801 path_set.insert(node);
802 for (size_t i = 0; i < node->inputs().size(); i++)
804 auto arg = node->get_input_node_shared_ptr(i);
805 if (path_set.find(arg) != path_set.end())
809 cycle_nodes.push_back(it);
812 cycle_nodes.push_back(arg);
815 if (check_for_cycles_bkwd(arg, path, path_set, cycle_nodes))
820 path_set.erase(path.back());
825 static bool check_for_cycles_fwd(std::shared_ptr<ngraph::Node> node,
826 std::deque<std::shared_ptr<ngraph::Node>>& path,
827 std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
828 ngraph::NodeVector& cycle_nodes)
830 path.push_back(node);
831 path_set.insert(node);
832 for (auto& arg : node->get_users())
834 if (path_set.find(arg) != path_set.end())
838 cycle_nodes.push_back(it);
841 cycle_nodes.push_back(arg);
844 if (check_for_cycles_fwd(arg, path, path_set, cycle_nodes))
849 path_set.erase(path.back());
854 bool ngraph::check_for_cycles(const ngraph::Function* func,
855 ngraph::NodeVector& cycle_nodes,
858 for (auto res : func->get_results())
860 std::deque<std::shared_ptr<Node>> path;
861 // mirror of path stack for faster cycle check
862 std::unordered_set<std::shared_ptr<Node>> path_set;
863 if (check_for_cycles_bkwd(res, path, path_set, cycle_nodes))
865 is_bkwd_cycle = true;
870 for (auto param : func->get_parameters())
872 std::deque<std::shared_ptr<Node>> path;
873 // mirror of path stack for faster cycle check
874 std::unordered_set<std::shared_ptr<Node>> path_set;
875 if (check_for_cycles_fwd(param, path, path_set, cycle_nodes))
877 is_bkwd_cycle = false;
885 bool ngraph::replace_output_update_name(Output<Node> output, const Output<Node>& replacement)
887 bool has_result_output = false;
888 for (auto& target_input : output.get_target_inputs())
890 if (is_type<op::Result>(target_input.get_node()))
892 // ignore trivial elimination
893 has_result_output = true;
894 if (is_type<ngraph::op::Parameter>(replacement.get_node()))
901 if (!has_result_output || replacement.get_node()->get_users().size() == 1)
903 if (has_result_output && !is_type<ngraph::op::Parameter>(replacement.get_node()))
905 replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name());
906 // Update output tensor name
907 replacement.get_tensor().set_name(output.get_node()->get_friendly_name());
909 output.replace(replacement);
910 copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
911 replacement.get_node_shared_ptr());
917 bool ngraph::replace_node_update_name(std::shared_ptr<Node> target,
918 std::shared_ptr<Node> replacement)
920 for (auto& output : target->output(0).get_target_inputs())
922 if (as_type<ngraph::op::Parameter>(replacement->input_value(0).get_node()) &&
923 as_type<op::Result>(output.get_node()))
928 replace_node(target, replacement);
929 replacement->set_friendly_name(target->get_friendly_name());
930 copy_runtime_info(target, replacement);