Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / EltwiseInputReshape_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from extensions.middle.EltwiseInputReshape import EltwiseInputReshape
22 from mo.middle.passes.eliminate_test import build_graph
23 from mo.middle.passes.fusing.fuse_linear_ops_test import compare_graphs
24
# Attributes of every node the test graphs below may reference. Each key is a node
# name and each value is the dictionary of attributes for that node; individual
# tests override some of these attributes (e.g. 'shape', 'value', 'dim') when they
# build a graph.
nodes_attributes = {
    # Placeholder layers. 'placeholder_2_data' and 'placeholder_3_data' are only
    # used as the extra constant data nodes in the reference graph of test3_constant.
    'placeholder_1': {'value': None, 'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},

    # Reshape layers and their output data nodes, as expected in the reference graphs
    'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
    'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'},

    'reshape_2': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
    'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'},

    # Fake consumer layers that read the placeholder data
    'consumer_1': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
    'consumer_2': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
    'consumer_3': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},

    # Concat that joins all consumers; the tests pass it to compare_graphs as the
    # node to compare from
    'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
}
49
50
class EltwiseInputReshapeTest(unittest.TestCase):
    """Tests for the EltwiseInputReshape middle transformation.

    Each test builds a graph in which some edges leaving 'placeholder_1_data'
    carry a 'new_shape' edge attribute, applies the transformation, and compares
    the result with a hand-built reference graph.
    """

    def _apply_and_compare(self, graph, graph_ref):
        """Run EltwiseInputReshape on ``graph`` and assert it matches ``graph_ref``.

        ``graph`` is modified in place by the transformation. The comparison is
        delegated to ``compare_graphs`` with 'concat' as the node to compare from
        and with operation attributes checked; its message is reported on failure.
        """
        EltwiseInputReshape().find_and_replace_pattern(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test1_not_constant(self):
        # Non-constant input, two different new_shape values: each reshaped edge
        # gets its own Reshape; the edge without new_shape is left untouched.
        #
        #        ,-------------->consumer3                 ,------------>consumer3
        #   data---(new_shape1)-->consumer1      =>    data---->Reshape-->consumer1
        #        `-(new_shape2)-->consumer2                 `-->Reshape-->consumer2
        #
        graph = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3, 1, 1]}),
                             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 1, 3]}),
                             ('placeholder_1_data', 'consumer_3'),
                             ('consumer_1', 'concat'),
                             ('consumer_2', 'concat'),
                             ('consumer_3', 'concat'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 3])}}, nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1_data', 'reshape_1'),
                                 ('placeholder_1_data', 'reshape_2'),
                                 ('placeholder_1_data', 'consumer_3'),
                                 ('reshape_1', 'reshape_1_data'),
                                 ('reshape_2', 'reshape_2_data'),
                                 ('reshape_1_data', 'consumer_1'),
                                 ('reshape_2_data', 'consumer_2'),
                                 ('consumer_1', 'concat'),
                                 ('consumer_2', 'concat'),
                                 ('consumer_3', 'concat'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 3])},
                                 'reshape_1': {'dim': np.array([1, 3, 1, 1])},
                                 'reshape_1_data': {'shape': np.array([1, 3, 1, 1])},
                                 'reshape_2': {'dim': np.array([1, 1, 3])},
                                 'reshape_2_data': {'shape': np.array([1, 1, 3])},
                                 }, nodes_with_edges_only=True)

        self._apply_and_compare(graph, graph_ref)

    def test2_not_constant(self):
        # Non-constant input, the same new_shape on two edges: a single Reshape is
        # inserted and shared by both consumers.
        #
        #        ,--------------->consumer3                ,----------->consumer3
        #   data---(new_shape1)-->consumer1      =>    data-->Reshape-->consumer1
        #        `-(new_shape1)-->consumer2                         `-->consumer2
        #
        graph = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3, 1, 1]}),
                             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 3, 1, 1]}),
                             ('placeholder_1_data', 'consumer_3'),
                             ('consumer_1', 'concat'),
                             ('consumer_2', 'concat'),
                             ('consumer_3', 'concat'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 3])}}, nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1_data', 'reshape_1'),
                                 ('placeholder_1_data', 'consumer_3'),
                                 ('reshape_1', 'reshape_1_data'),
                                 ('reshape_1_data', 'consumer_1'),
                                 ('reshape_1_data', 'consumer_2'),
                                 ('consumer_1', 'concat'),
                                 ('consumer_2', 'concat'),
                                 ('consumer_3', 'concat'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 3])},
                                 'reshape_1': {'dim': np.array([1, 3, 1, 1])},
                                 'reshape_1_data': {'shape': np.array([1, 3, 1, 1])},
                                 }, nodes_with_edges_only=True)

        self._apply_and_compare(graph, graph_ref)

    def test3_constant(self):
        # Constant input, two different new_shape values: no Reshape is inserted;
        # instead each consumer reads its own constant data node whose shape and
        # value are already reshaped (consumer3 keeps the original [1, 3] shape).
        #
        #        ,--------------->consumer3            const[1, 3]------->consumer3
        #   data---(new_shape1)-->consumer1      =>    const[new_shape1]-->consumer1
        #        `-(new_shape2)-->consumer2            const[new_shape2]-->consumer2
        #
        graph = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3, 1, 1]}),
                             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 1, 3]}),
                             ('placeholder_1_data', 'consumer_3'),
                             ('consumer_1', 'concat'),
                             ('consumer_2', 'concat'),
                             ('consumer_3', 'concat'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 3]), 'value': np.ones([1, 3])}},
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1_data', 'consumer_1'),
                                 ('placeholder_2_data', 'consumer_2'),
                                 ('placeholder_3_data', 'consumer_3'),
                                 ('consumer_1', 'concat'),
                                 ('consumer_2', 'concat'),
                                 ('consumer_3', 'concat'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 3, 1, 1]), 'value': np.ones([1, 3, 1, 1])},
                                 'placeholder_2_data': {'shape': np.array([1, 1, 3]), 'value': np.ones([1, 1, 3])},
                                 'placeholder_3_data': {'shape': np.array([1, 3]), 'value': np.ones([1, 3])},
                                 }, nodes_with_edges_only=True)

        self._apply_and_compare(graph, graph_ref)

    def test4_constant(self):
        # Constant input, all three edges carry the same new_shape: the single
        # constant data node is reshaped in place and still feeds all consumers.
        #
        #        ,-(new_shape1)-->consumer3                           ,-->consumer3
        #   data---(new_shape1)-->consumer1      =>    data[new_shape1]---->consumer1
        #        `-(new_shape1)-->consumer2                           `-->consumer2
        #
        graph = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'consumer_1', {'new_shape': [3, 1, 1]}),
                             ('placeholder_1_data', 'consumer_2', {'new_shape': [3, 1, 1]}),
                             ('placeholder_1_data', 'consumer_3', {'new_shape': [3, 1, 1]}),
                             ('consumer_1', 'concat'),
                             ('consumer_2', 'concat'),
                             ('consumer_3', 'concat'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 3]), 'value': np.ones([1, 3])}},
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1_data', 'consumer_1'),
                                 ('placeholder_1_data', 'consumer_2'),
                                 ('placeholder_1_data', 'consumer_3'),
                                 ('consumer_1', 'concat'),
                                 ('consumer_2', 'concat'),
                                 ('consumer_3', 'concat'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([3, 1, 1]), 'value': np.ones([3, 1, 1])}
                                 }, nodes_with_edges_only=True)

        self._apply_and_compare(graph, graph_ref)

    def test5_not_constant(self):
        # Non-constant input whose new_shape equals its current shape [1, 3]:
        # the expected graph is identical to the input graph (no Reshape added).
        #
        #        ,--------------->consumer3                ,->consumer3
        #   data---(new_shape1)-->consumer1      =>    data----->consumer1
        #        `-(new_shape1)-->consumer2                `-->consumer2
        #
        graph = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3]}),
                             ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 3]}),
                             ('placeholder_1_data', 'consumer_3'),
                             ('consumer_1', 'concat'),
                             ('consumer_2', 'concat'),
                             ('consumer_3', 'concat'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 3])}}, nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1_data', 'consumer_1', {'new_shape': [1, 3]}),
                                 ('placeholder_1_data', 'consumer_2', {'new_shape': [1, 3]}),
                                 ('placeholder_1_data', 'consumer_3'),
                                 ('consumer_1', 'concat'),
                                 ('consumer_2', 'concat'),
                                 ('consumer_3', 'concat'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 3])}}, nodes_with_edges_only=True)

        self._apply_and_compare(graph, graph_ref)