Correct removing nodes from graph and add test for ConstToResult transform (#1084)
authorRoman Kazantsev <roman.kazantsev@intel.com>
Tue, 23 Jun 2020 10:49:14 +0000 (15:49 +0500)
committerGitHub <noreply@github.com>
Tue, 23 Jun 2020 10:49:14 +0000 (13:49 +0300)
Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
model-optimizer/extensions/back/SpecialNodesFinalization.py
model-optimizer/extensions/back/SpecialNodesFinalization_test.py

index 333c503521a31a361afb27afa48d8cfc7eb641fa..876372c897bd7502b99303480c8d52db15a5b91e 100644 (file)
@@ -148,7 +148,7 @@ class RemoveConstToResult(BackReplacementPattern):
             nodes_to_remove.append(const_node.id)
             nodes_to_remove.append(const_data_node.id)
 
-        graph.remove_node(nodes_to_remove)
+        graph.remove_nodes_from(nodes_to_remove)
 
 
 class NormalizeTI(BackReplacementPattern):
index c5f58885db87ddf5353030717cad3450fccf5102..676f749f5514594eace2e85c40fad1064fd0399a 100644 (file)
@@ -17,7 +17,7 @@ import unittest
 
 import numpy as np
 
-from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement
+from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement, RemoveConstToResult
 from mo.utils.ir_engine.compare_graphs import compare_graphs
 from mo.utils.unittest.graph import build_graph_with_attrs
 
@@ -112,3 +112,93 @@ class CreateConstNodesReplacementTest(unittest.TestCase):
         (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
         self.assertTrue(flag, resp)
 
+
+class RemoveConstToResultReplacementTest(unittest.TestCase):
+    def test_only_consumer(self):
+        """Result node is only consumer of Const data node"""
+        nodes = [
+            ('const_node', {'type': 'Const', 'kind': 'op'}),
+            ('const_data', {'kind': 'data'}),
+            ('result_node', {'type': 'Result', 'kind': 'op'}),
+
+            ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
+            ('placeholder_1_data', {'kind': 'data'}),
+            ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+            ('relu_1_data', {'kind': 'data'}),
+        ]
+        edges = [
+            ('const_node', 'const_data'),
+            ('const_data', 'result_node'),
+
+            ('placeholder_1', 'placeholder_1_data'),
+            ('placeholder_1_data', 'relu_1'),
+            ('relu_1', 'relu_1_data')
+        ]
+        new_nodes=[
+            ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
+            ('placeholder_1_data', {'kind': 'data'}),
+            ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+            ('relu_1_data', {'kind': 'data'}),
+        ]
+        new_edges=[
+            ('placeholder_1', 'placeholder_1_data'),
+            ('placeholder_1_data', 'relu_1'),
+            ('relu_1', 'relu_1_data')
+        ]
+
+        graph = build_graph_with_attrs(
+            nodes_with_attrs=nodes,
+            edges_with_attrs=edges,
+        )
+        graph_ref = build_graph_with_attrs(
+            nodes_with_attrs=new_nodes,
+            edges_with_attrs=new_edges,
+        )
+        tested_pattern = RemoveConstToResult()
+        tested_pattern.find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
+        self.assertTrue(flag, resp)
+        self.assertNotIn('const_node', graph.node)
+        self.assertNotIn('const_data', graph.node)
+        self.assertNotIn('result_node', graph.node)
+
+    def test_two_consumers(self):
+        """Const data node has two consumers: Result and ReLu"""
+        nodes = [
+            ('const_node', {'type': 'Const', 'kind': 'op'}),
+            ('const_data', {'kind': 'data'}),
+            ('result_node', {'type': 'Result', 'kind': 'op'}),
+            ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+            ('relu_1_data', {'kind': 'data'}),
+        ]
+        edges = [
+            ('const_node', 'const_data'),
+            ('const_data', 'result_node'),
+            ('const_data', 'relu_1'),
+            ('relu_1', 'relu_1_data')
+        ]
+        new_nodes=[
+            ('const_node', {'type': 'Const', 'kind': 'op'}),
+            ('const_data', {'kind': 'data'}),
+            ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+            ('relu_1_data', {'kind': 'data'}),
+        ]
+        new_edges=[
+            ('const_node', 'const_data'),
+            ('const_data', 'relu_1'),
+            ('relu_1', 'relu_1_data')
+        ]
+
+        graph = build_graph_with_attrs(
+            nodes_with_attrs=nodes,
+            edges_with_attrs=edges,
+        )
+        graph_ref = build_graph_with_attrs(
+            nodes_with_attrs=new_nodes,
+            edges_with_attrs=new_edges,
+        )
+        tested_pattern = RemoveConstToResult()
+        tested_pattern.find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
+        self.assertTrue(flag, resp)
+        self.assertNotIn('result_node', graph.node)