2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
19 from mo.utils.unittest.graph import build_graph_with_attrs, compare_graphs
class CreateConstNodesReplacementTest(unittest.TestCase):
    """Unit tests for the CreateConstNodesReplacement back transformation.

    The base graph is a single data node feeding a single op
    (``data_node -> next_node``). When the transformation applies, the
    expected graph additionally contains a ``const`` op producing
    ``data_node`` and a ``const_data`` node feeding that ``const`` op.
    When every consumer edge of the data node carries the ``bin``
    attribute, the graph is expected to stay unchanged.
    """

    # Input graph: data_node -> next_node.
    nodes = [
        ('data_node', {'kind': 'data', 'shape': None, 'value': None}),
        ('next_node', {'kind': 'op'}),
    ]
    edges = [
        ('data_node', 'next_node'),
    ]

    # Nodes and edges the transformation is expected to add.
    new_nodes = [
        ('const', {'kind': 'op', 'op': 'Const'}),
        ('const_data', {'kind': 'data'}),
    ]
    new_edges = [
        ('const', 'data_node'),
        ('const_data', 'const'),
    ]

    @staticmethod
    def _apply_pattern(graph):
        """Run CreateConstNodesReplacement on ``graph`` (modifies it in place)."""
        CreateConstNodesReplacement().find_and_replace_pattern(graph)

    def _assert_graphs_equal(self, graph, graph_ref, last_node='next_node'):
        """Fail the test with compare_graphs' message if the graphs differ."""
        (flag, resp) = compare_graphs(graph, graph_ref, last_node=last_node)
        self.assertTrue(flag, resp)

    def test_one_node(self):
        """We should add Const node and data node."""
        shape = np.array([2, 3, 4])
        data = np.zeros(shape)
        graph = build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
        )
        graph_ref = build_graph_with_attrs(
            nodes_with_attrs=self.nodes + self.new_nodes,
            edges_with_attrs=self.edges + self.new_edges,
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data}),
                                     ('const_data', {'shape': shape, 'value': data})],
        )
        self._apply_pattern(graph)
        self._assert_graphs_equal(graph, graph_ref)

    def test_one_bin_node(self):
        """Nothing should happen when the only consumer edge has 'bin'."""
        shape = np.array([2, 3, 4])
        data = np.zeros(shape)

        def make_graph():
            # Fresh attribute literals on every call so the two graphs share
            # no mutable update dicts.
            return build_graph_with_attrs(
                nodes_with_attrs=self.nodes,
                edges_with_attrs=self.edges,
                update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
                update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
            )

        graph = make_graph()
        # Independent, untransformed reference. Comparing the transformed
        # graph with itself could never fail and would make the test vacuous.
        graph_ref = make_graph()
        self._apply_pattern(graph)
        self._assert_graphs_equal(graph, graph_ref)

    def test_force_precision_parameter(self):
        """The 'force_precision' of a data node is copied to the new nodes."""
        precision = 'FP16'
        shape = np.array([2, 3, 4])
        data = np.zeros(shape)
        graph = build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data,
                                                    'force_precision': precision})],
        )
        graph_ref = build_graph_with_attrs(
            nodes_with_attrs=self.nodes + self.new_nodes,
            edges_with_attrs=self.edges + self.new_edges,
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data}),
                                     ('const_data', {'shape': shape, 'value': data,
                                                     'force_precision': precision}),
                                     ('const', {'force_precision': precision})],
        )
        self._apply_pattern(graph)
        self._assert_graphs_equal(graph, graph_ref)

        # Check that force_precision was added to the created Const op and to
        # the new data node, looked up by the names the transformation gives them.
        force_precision_const_node = graph.nodes['data_node_const']['force_precision']
        force_precision_new_data = graph.nodes['data_node_copy_']['force_precision']
        self.assertEqual(force_precision_const_node, precision)
        self.assertEqual(force_precision_new_data, precision)

    def test_two_nodes_with_bin(self):
        """Data node with 2 consumers, both through 'bin' edges.

        Nothing should happen.
        """
        shape = np.array([2, 3, 4])
        data = np.zeros(shape)

        def make_graph():
            # Fresh attribute literals on every call so the two graphs share
            # no mutable update dicts.
            return build_graph_with_attrs(
                nodes_with_attrs=self.nodes + [('next_node_2', {'kind': 'op'})],
                edges_with_attrs=self.edges + [('data_node', 'next_node_2')],
                update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
                update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0},
                                   ('data_node', 'next_node_2', 0): {'bin': 0}},
            )

        graph = make_graph()
        # Independent, untransformed reference (see test_one_bin_node).
        graph_ref = make_graph()
        self._apply_pattern(graph)
        self._assert_graphs_equal(graph, graph_ref)

    def test_two_nodes_one_bin(self):
        """Test case for two output nodes, one with 'bin' parameter, other without."""
        shape = np.array([2, 3, 4])
        data = np.zeros(shape)
        graph = build_graph_with_attrs(
            nodes_with_attrs=self.nodes + [('next_node_2', {'kind': 'op'})],
            edges_with_attrs=self.edges + [('data_node', 'next_node_2')],
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
            update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
        )
        graph_ref = build_graph_with_attrs(
            nodes_with_attrs=self.nodes + self.new_nodes + [('next_node_2', {'kind': 'op'})],
            edges_with_attrs=self.edges + self.new_edges + [('data_node', 'next_node_2')],
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data}),
                                     ('const_data', {'shape': shape, 'value': data})],
        )
        self._apply_pattern(graph)
        self._assert_graphs_equal(graph, graph_ref)