From 5ad1bf643d8dbe1761eb3767829cbea0f538a560 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 23 Jun 2020 15:49:14 +0500 Subject: [PATCH] Correct removing nodes from graph and add test for ConstToResult transform (#1084) Signed-off-by: Roman Kazantsev --- .../extensions/back/SpecialNodesFinalization.py | 2 +- .../back/SpecialNodesFinalization_test.py | 92 +++++++++++++++++++++- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/model-optimizer/extensions/back/SpecialNodesFinalization.py b/model-optimizer/extensions/back/SpecialNodesFinalization.py index 333c503..876372c 100644 --- a/model-optimizer/extensions/back/SpecialNodesFinalization.py +++ b/model-optimizer/extensions/back/SpecialNodesFinalization.py @@ -148,7 +148,7 @@ class RemoveConstToResult(BackReplacementPattern): nodes_to_remove.append(const_node.id) nodes_to_remove.append(const_data_node.id) - graph.remove_node(nodes_to_remove) + graph.remove_nodes_from(nodes_to_remove) class NormalizeTI(BackReplacementPattern): diff --git a/model-optimizer/extensions/back/SpecialNodesFinalization_test.py b/model-optimizer/extensions/back/SpecialNodesFinalization_test.py index c5f5888..676f749 100644 --- a/model-optimizer/extensions/back/SpecialNodesFinalization_test.py +++ b/model-optimizer/extensions/back/SpecialNodesFinalization_test.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement +from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement, RemoveConstToResult from mo.utils.ir_engine.compare_graphs import compare_graphs from mo.utils.unittest.graph import build_graph_with_attrs @@ -112,3 +112,93 @@ class CreateConstNodesReplacementTest(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node') self.assertTrue(flag, resp) + +class RemoveConstToResultReplacementTest(unittest.TestCase): + def test_only_consumer(self): + """Result node is only consumer of Const data node""" + nodes = [ + ('const_node', {'type': 'Const', 'kind': 'op'}), + ('const_data', {'kind': 'data'}), + ('result_node', {'type': 'Result', 'kind': 'op'}), + + ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}), + ('placeholder_1_data', {'kind': 'data'}), + ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}), + ('relu_1_data', {'kind': 'data'}), + ] + edges = [ + ('const_node', 'const_data'), + ('const_data', 'result_node'), + + ('placeholder_1', 'placeholder_1_data'), + ('placeholder_1_data', 'relu_1'), + ('relu_1', 'relu_1_data') + ] + new_nodes=[ + ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}), + ('placeholder_1_data', {'kind': 'data'}), + ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}), + ('relu_1_data', {'kind': 'data'}), + ] + new_edges=[ + ('placeholder_1', 'placeholder_1_data'), + ('placeholder_1_data', 'relu_1'), + ('relu_1', 'relu_1_data') + ] + + graph = build_graph_with_attrs( + nodes_with_attrs=nodes, + edges_with_attrs=edges, + ) + graph_ref = build_graph_with_attrs( + nodes_with_attrs=new_nodes, + edges_with_attrs=new_edges, + ) + tested_pattern = RemoveConstToResult() + tested_pattern.find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data') + self.assertTrue(flag, resp) + self.assertNotIn('const_node', graph.node) + self.assertNotIn('const_data', graph.node) + self.assertNotIn('result_node', graph.node) + + def test_two_consumers(self): + """Const data node has two consumers: Result and ReLu""" + nodes = [ + ('const_node', {'type': 'Const', 'kind': 'op'}), + ('const_data', {'kind': 'data'}), + ('result_node', {'type': 'Result', 'kind': 'op'}), + ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}), + ('relu_1_data', {'kind': 'data'}), + ] + edges = [ + ('const_node', 'const_data'), + ('const_data', 'result_node'), + ('const_data', 'relu_1'), + ('relu_1', 'relu_1_data') + ] + new_nodes=[ + ('const_node', {'type': 'Const', 'kind': 'op'}), + ('const_data', {'kind': 'data'}), + ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}), + ('relu_1_data', {'kind': 'data'}), + ] + new_edges=[ + ('const_node', 'const_data'), + ('const_data', 'relu_1'), + ('relu_1', 'relu_1_data') + ] + + graph = build_graph_with_attrs( + nodes_with_attrs=nodes, + edges_with_attrs=edges, + ) + graph_ref = build_graph_with_attrs( + nodes_with_attrs=new_nodes, + edges_with_attrs=new_edges, + ) + tested_pattern = RemoveConstToResult() + tested_pattern.find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data') + self.assertTrue(flag, resp) + self.assertNotIn('result_node', graph.node) -- 2.7.4