"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest

import numpy as np

from extensions.middle.EltwiseInputReshape import EltwiseInputReshape
from mo.middle.passes.eliminate_test import build_graph
from mo.middle.passes.fusing.fuse_linear_ops_test import compare_graphs
# The dictionary with node attributes used to build the various test graphs. A key is the name of the node and the
# value is the dictionary with that node's attributes. Every 'shape' and 'value' defaults to None here; individual
# tests override them through the attribute-update argument of build_graph.
nodes_attributes = {
    # Placeholder layers
    'placeholder_1': {'value': None, 'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},

    # Reshape layers
    'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
    'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'},

    'reshape_2': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
    'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'},

    # Fake consumer layers
    'consumer_1': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
    'consumer_2': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
    'consumer_3': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},

    # Concat
    'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
}
class EltwiseInputReshapeTest(unittest.TestCase):
    """Tests for the ``EltwiseInputReshape`` middle pass.

    Every test builds a graph in which the data node ``placeholder_1_data`` feeds three fake consumers that are
    joined by a Concat, with some data->consumer edges carrying a ``new_shape`` attribute. The pass is run on that
    graph and the result is compared against a hand-built reference graph.
    """

    @staticmethod
    def _concat_edges():
        """Return a fresh list of the edges wiring all three fake consumers into the Concat node."""
        return [('consumer_1', 'concat'),
                ('consumer_2', 'concat'),
                ('consumer_3', 'concat')]

    def _build(self, edges, attrs):
        """Build a graph from ``edges`` followed by the common consumer->concat edges.

        :param edges: list of edge tuples that precede the consumer->concat edges.
        :param attrs: per-node attribute overrides, forwarded to ``build_graph``.
        :return: whatever ``build_graph`` returns for these arguments.
        """
        return build_graph(nodes_attributes, edges + self._concat_edges(), attrs, nodes_with_edges_only=True)

    def _apply_and_compare(self, graph, graph_ref):
        """Run ``EltwiseInputReshape`` on ``graph`` and assert that it matches ``graph_ref``.

        The comparison is made by ``compare_graphs`` with 'concat' as the node name and op attributes checked; its
        response is used as the assertion failure message.
        """
        EltwiseInputReshape().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test1_not_constant(self):
        """Non-constant data, two different new shapes: one Reshape (with its data node) is expected per shape."""
        #        ,-------------->consumer3                 ,------------>consumer3
        #   data---(new_shape1)-->consumer1      =>    data---->Reshape-->consumer1
        #        `-(new_shape2)-->consumer2                 `-->Reshape-->consumer2
        #
        data_attrs = {'placeholder_1_data': {'shape': np.array([1, 3])}}
        graph = self._build(
            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3, 1, 1]}),
             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 1, 3]}),
             ('placeholder_1_data', 'consumer_3'),
             ],
            data_attrs)

        graph_ref = self._build(
            [('placeholder_1_data', 'reshape_1'),
             ('placeholder_1_data', 'reshape_2'),
             ('placeholder_1_data', 'consumer_3'),
             ('reshape_1', 'reshape_1_data'),
             ('reshape_2', 'reshape_2_data'),
             ('reshape_1_data', 'consumer_1'),
             ('reshape_2_data', 'consumer_2'),
             ],
            {'placeholder_1_data': {'shape': np.array([1, 3])},
             'reshape_1': {'dim': np.array([1, 3, 1, 1])},
             'reshape_1_data': {'shape': np.array([1, 3, 1, 1])},
             'reshape_2': {'dim': np.array([1, 1, 3])},
             'reshape_2_data': {'shape': np.array([1, 1, 3])},
             })

        self._apply_and_compare(graph, graph_ref)

    def test2_not_constant(self):
        """Non-constant data, two edges with the same new shape: a single shared Reshape is expected."""
        #        ,--------------->consumer3                ,----------->consumer3
        #   data---(new_shape1)-->consumer1      =>    data-->Reshape-->consumer1
        #        `-(new_shape1)-->consumer2                         `-->consumer2
        #
        data_attrs = {'placeholder_1_data': {'shape': np.array([1, 3])}}
        graph = self._build(
            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3, 1, 1]}),
             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 3, 1, 1]}),
             ('placeholder_1_data', 'consumer_3'),
             ],
            data_attrs)

        graph_ref = self._build(
            [('placeholder_1_data', 'reshape_1'),
             ('placeholder_1_data', 'consumer_3'),
             ('reshape_1', 'reshape_1_data'),
             ('reshape_1_data', 'consumer_1'),
             ('reshape_1_data', 'consumer_2'),
             ],
            {'placeholder_1_data': {'shape': np.array([1, 3])},
             'reshape_1': {'dim': np.array([1, 3, 1, 1])},
             'reshape_1_data': {'shape': np.array([1, 3, 1, 1])},
             })

        self._apply_and_compare(graph, graph_ref)

    def test3_constant(self):
        """Constant data, two different new shapes: a separate data node per consumer is expected, no Reshape ops."""
        #        ,--------------->consumer3            data-->consumer3
        #   data---(new_shape1)-->consumer1      =>    data-->consumer1
        #        `-(new_shape2)-->consumer2            data-->consumer2
        #
        graph = self._build(
            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3, 1, 1]}),
             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 1, 3]}),
             ('placeholder_1_data', 'consumer_3'),
             ],
            {'placeholder_1_data': {'shape': np.array([1, 3]), 'value': np.ones([1, 3])}})

        graph_ref = self._build(
            [('placeholder_1_data', 'consumer_1'),
             ('placeholder_2_data', 'consumer_2'),
             ('placeholder_3_data', 'consumer_3'),
             ],
            {'placeholder_1_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.ones([1, 3, 1, 1])},
             'placeholder_2_data': {'shape': np.array([1, 1, 3]), 'value': np.ones([1, 1, 3])},
             'placeholder_3_data': {'shape': np.array([1, 3]), 'value': np.ones([1, 3])},
             })

        self._apply_and_compare(graph, graph_ref)

    def test4_constant(self):
        """Constant data, the same new shape on all three edges: one data node with the new shape is expected."""
        #        ,-(new_shape1)-->consumer3                ,-->consumer3
        #   data---(new_shape1)-->consumer1      =>    data--->consumer1
        #        `-(new_shape1)-->consumer2                `-->consumer2
        #
        graph = self._build(
            [('placeholder_1_data', 'consumer_1', {'new_shape': [3, 1, 1]}),
             ('placeholder_1_data', 'consumer_2', {'new_shape': [3, 1, 1]}),
             ('placeholder_1_data', 'consumer_3', {'new_shape': [3, 1, 1]}),
             ],
            {'placeholder_1_data': {'shape': np.array([1, 3]), 'value': np.ones([1, 3])}})

        graph_ref = self._build(
            [('placeholder_1_data', 'consumer_1'),
             ('placeholder_1_data', 'consumer_2'),
             ('placeholder_1_data', 'consumer_3'),
             ],
            {'placeholder_1_data': {'shape': np.array([3, 1, 1]), 'value': np.ones([3, 1, 1])}
             })

        self._apply_and_compare(graph, graph_ref)

    def test5_not_constant(self):
        """Non-constant data whose shape already equals the requested new shape: no change is expected."""
        #        ,--------------->consumer3                ,->consumer3
        #   data---(new_shape1)-->consumer1      =>    data----->consumer1
        #        `-(new_shape1)-->consumer2                `-->consumer2
        #
        data_attrs = {'placeholder_1_data': {'shape': np.array([1, 3])}}
        graph = self._build(
            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3]}),
             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 3]}),
             ('placeholder_1_data', 'consumer_3'),
             ],
            data_attrs)

        # The reference is built exactly like the input graph, 'new_shape' edge attributes included.
        graph_ref = self._build(
            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3]}),
             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 3]}),
             ('placeholder_1_data', 'consumer_3'),
             ],
            {'placeholder_1_data': {'shape': np.array([1, 3])}})

        self._apply_and_compare(graph, graph_ref)