import numpy as np
-from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement
+from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement, RemoveConstToResult
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph_with_attrs
(flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
self.assertTrue(flag, resp)
+
+class RemoveConstToResultReplacementTest(unittest.TestCase):
+ def test_only_consumer(self):
+ """Result node is only consumer of Const data node"""
+ nodes = [
+ ('const_node', {'type': 'Const', 'kind': 'op'}),
+ ('const_data', {'kind': 'data'}),
+ ('result_node', {'type': 'Result', 'kind': 'op'}),
+
+ ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
+ ('placeholder_1_data', {'kind': 'data'}),
+ ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+ ('relu_1_data', {'kind': 'data'}),
+ ]
+ edges = [
+ ('const_node', 'const_data'),
+ ('const_data', 'result_node'),
+
+ ('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'relu_1'),
+ ('relu_1', 'relu_1_data')
+ ]
+ new_nodes=[
+ ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
+ ('placeholder_1_data', {'kind': 'data'}),
+ ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+ ('relu_1_data', {'kind': 'data'}),
+ ]
+ new_edges=[
+ ('placeholder_1', 'placeholder_1_data'),
+ ('placeholder_1_data', 'relu_1'),
+ ('relu_1', 'relu_1_data')
+ ]
+
+ graph = build_graph_with_attrs(
+ nodes_with_attrs=nodes,
+ edges_with_attrs=edges,
+ )
+ graph_ref = build_graph_with_attrs(
+ nodes_with_attrs=new_nodes,
+ edges_with_attrs=new_edges,
+ )
+ tested_pattern = RemoveConstToResult()
+ tested_pattern.find_and_replace_pattern(graph)
+ (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
+ self.assertTrue(flag, resp)
+ self.assertNotIn('const_node', graph.node)
+ self.assertNotIn('const_data', graph.node)
+ self.assertNotIn('result_node', graph.node)
+
+ def test_two_consumers(self):
+ """Const data node has two consumers: Result and ReLu"""
+ nodes = [
+ ('const_node', {'type': 'Const', 'kind': 'op'}),
+ ('const_data', {'kind': 'data'}),
+ ('result_node', {'type': 'Result', 'kind': 'op'}),
+ ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+ ('relu_1_data', {'kind': 'data'}),
+ ]
+ edges = [
+ ('const_node', 'const_data'),
+ ('const_data', 'result_node'),
+ ('const_data', 'relu_1'),
+ ('relu_1', 'relu_1_data')
+ ]
+ new_nodes=[
+ ('const_node', {'type': 'Const', 'kind': 'op'}),
+ ('const_data', {'kind': 'data'}),
+ ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+ ('relu_1_data', {'kind': 'data'}),
+ ]
+ new_edges=[
+ ('const_node', 'const_data'),
+ ('const_data', 'relu_1'),
+ ('relu_1', 'relu_1_data')
+ ]
+
+ graph = build_graph_with_attrs(
+ nodes_with_attrs=nodes,
+ edges_with_attrs=edges,
+ )
+ graph_ref = build_graph_with_attrs(
+ nodes_with_attrs=new_nodes,
+ edges_with_attrs=new_edges,
+ )
+ tested_pattern = RemoveConstToResult()
+ tested_pattern.find_and_replace_pattern(graph)
+ (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
+ self.assertTrue(flag, resp)
+ self.assertNotIn('result_node', graph.node)