Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / ops / switch_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18 from unittest.mock import Mock, call
19
20 import numpy as np
21
22 from extensions.ops.switch import Switch
23 from mo.graph.graph import Node
24 from mo.utils.unittest.graph import build_graph_with_edge_attrs, compare_graphs, build_graph_with_attrs
25
26
27 class TestSwitch(unittest.TestCase):
28     def test_switch_infer_with_condition(self):
29         nodes = [
30             ('tensor', {'value': np.zeros((3, 3)), 'kind': 'data', 'executable': True, 'shape': np.array([3, 3])}),
31             ('pred_id', {'value': True, 'kind': 'data', 'executable': True}),
32             ('switch', {'type': 'Switch', 'kind': 'op', 'op': 'Switch'}),
33             ('switch_data_0', {'value': None, 'kind': 'data', 'executable': True}),
34             ('switch_data_1', {'value': None, 'kind': 'data', 'executable': True})
35         ]
36         edges = [
37             ('tensor', 'switch', {'in': 0}),
38             ('pred_id', 'switch', {'in': 1}),
39             ('switch', 'switch_data_0', {'out': 0}),
40             ('switch', 'switch_data_1', {'out': 1})
41         ]
42         graph = build_graph_with_attrs(nodes_with_attrs=nodes, edges_with_attrs=edges)
43
44         # We should propagate shapes and values
45         graph_ref = build_graph_with_attrs(nodes_with_attrs=nodes,
46                                            edges_with_attrs=edges,
47                                            update_nodes_attributes=[('switch_data_0', {'shape': np.array([3, 3]),
48                                                                                        'value': np.zeros((3,3))}),
49                                                                     ('switch_data_1', {'shape': np.array([3, 3]),
50                                                                                        'value': np.zeros((3,3))})])
51
52         tested_class = Switch(graph=graph, attrs={})
53
54         node = Node(graph, 'switch')
55         tested_class.infer(node)
56
57         (flag, resp) = compare_graphs(graph, graph_ref, 'switch_data_0', check_op_attrs=True)
58         self.assertTrue(flag, resp)
59
60     def test_switch_infer_no_condition(self):
61         nodes = [
62             ('tensor', {'value': None, 'kind': 'data', 'executable': True, 'shape': np.array([1, 2, 1])}),
63             ('pred_id', {'value': None, 'kind': 'data', 'executable': True}),
64             ('switch', {'type': 'Switch', 'kind': 'op', 'op': 'Switch'}),
65             ('switch_data_0', {'value': None, 'kind': 'data', 'executable': True}),
66             ('switch_data_1', {'value': None, 'kind': 'data', 'executable': True})
67         ]
68         edges = [
69             ('tensor', 'switch', {'in': 0}),
70             ('pred_id', 'switch', {'in': 1}),
71             ('switch', 'switch_data_0', {'out': 0}),
72             ('switch', 'switch_data_1', {'out': 1})
73         ]
74         graph = build_graph_with_attrs(nodes_with_attrs=nodes, edges_with_attrs=edges)
75
76         # We should propagate only shapes
77         graph_ref = build_graph_with_attrs(nodes_with_attrs=nodes,
78                                            edges_with_attrs=edges,
79                                            update_nodes_attributes=[('switch_data_0', {'shape': np.array([1, 2, 1])}),
80                                                                     ('switch_data_1', {'shape': np.array([1, 2, 1])})])
81
82         tested_class = Switch(graph=graph, attrs={})
83
84         node = Node(graph, 'switch')
85         tested_class.infer(node)
86
87         (flag, resp) = compare_graphs(graph, graph_ref, 'switch_data_0', check_op_attrs=True)
88         self.assertTrue(flag, resp)
89
90     def test_switch_cf_infer_no_condition(self):
91         me_mock = Mock()
92         nodes = {
93             'tensor': {'value': True, 'kind': 'data', 'executable': True},
94             'pred_id': {'value': None, 'kind': 'data', 'executable': True},
95             'switch': {'type': 'Switch', 'kind': 'op', 'op': 'Switch'},
96             'switch_data_0': {'value': None, 'kind': 'data', 'executable': True},
97             'switch_data_1': {'value': None, 'kind': 'data', 'executable': True}
98         }
99         edges = [
100             ('tensor', 'switch', {'in': 0}),
101             ('pred_id', 'switch', {'in': 1}),
102             ('switch', 'switch_data_0', {'out': 0}),
103             ('switch', 'switch_data_1', {'out': 1})
104         ]
105         graph = build_graph_with_edge_attrs(nodes, edges)
106
107         tested_class = Switch(graph=graph, attrs={})
108         node = Node(graph, 'switch')
109         tested_class.control_flow_infer(node, True, me_mock)
110         # In this case we should mark all ports as executable
111         me_mock.assert_has_calls([call('switch_data_0', True), call('switch_data_1', True)], any_order=True)
112
113     def test_switch_cf_true_both_ports(self):
114         me_mock = Mock()
115         nodes = {
116             'tensor': {'value': True, 'kind': 'data', 'executable': True},
117             'pred_id': {'value': np.array(True), 'kind': 'data', 'executable': True},
118             'switch': {'type': 'Switch', 'kind': 'op', 'op': 'Switch'},
119             'switch_data_0': {'value': None, 'kind': 'data', 'executable': True},
120             'switch_data_1': {'value': None, 'kind': 'data', 'executable': True}
121         }
122         edges = [
123             ('tensor', 'switch', {'in': 0}),
124             ('pred_id', 'switch', {'in': 1}),
125             ('switch', 'switch_data_0', {'out': 0}),
126             ('switch', 'switch_data_1', {'out': 1})
127         ]
128         graph = build_graph_with_edge_attrs(nodes, edges)
129         tested_class = Switch(graph=graph, attrs={})
130         node = Node(graph, 'switch')
131         tested_class.control_flow_infer(node, True, me_mock)
132         me_mock.assert_has_calls([call('switch_data_0', False), call('switch_data_1', True)], any_order=True)
133
134     def test_switch_cf_false_both_ports(self):
135         me_mock = Mock()
136
137         nodes = {
138             'tensor': {'value': True, 'kind': 'data', 'executable': True},
139             'pred_id': {'value': np.array(False), 'kind': 'data', 'executable': True},
140             'switch': {'type': 'Switch', 'kind': 'op', 'op': 'Switch'},
141             'switch_data_0': {'value': None, 'kind': 'data', 'executable': True},
142             'switch_data_1': {'value': None, 'kind': 'data', 'executable': True}
143         }
144         edges = [
145             ('tensor', 'switch', {'in': 0}),
146             ('pred_id', 'switch', {'in': 1}),
147             ('switch', 'switch_data_0', {'out': 0}),
148             ('switch', 'switch_data_1', {'out': 1})
149         ]
150         graph = build_graph_with_edge_attrs(nodes, edges)
151         tested_class = Switch(graph=graph, attrs={})
152         node = Node(graph, 'switch')
153         tested_class.control_flow_infer(node, True, me_mock)
154         me_mock.assert_has_calls([call('switch_data_0', True), call('switch_data_1', False)], any_order=True)
155
156     def test_switch_cf_true_one_exec_port(self):
157         me_mock = Mock()
158
159         nodes = {
160             'tensor': {'value': True, 'kind': 'data', 'executable': True},
161             'pred_id': {'value': np.array(True), 'kind': 'data', 'executable': True},
162             'switch': {'type': 'Switch', 'kind': 'op', 'op': 'Switch'},
163             'switch_data_1': {'value': None, 'kind': 'data', 'executable': True}
164         }
165         edges = [
166             ('tensor', 'switch', {'in': 0}),
167             ('pred_id', 'switch', {'in': 1}),
168             ('switch', 'switch_data_1', {'out': 1})
169         ]
170         graph = build_graph_with_edge_attrs(nodes, edges)
171         tested_class = Switch(graph=graph, attrs={})
172         node = Node(graph, 'switch')
173         tested_class.control_flow_infer(node, True, me_mock)
174         me_mock.assert_has_calls([call('switch_data_1', True)], any_order=True)
175
176     def test_switch_cf_false_one_exec_port(self):
177         me_mock = Mock()
178
179         nodes = {
180             'tensor': {'value': True, 'kind': 'data', 'executable': True},
181             'pred_id': {'value': np.array(False), 'kind': 'data', 'executable': True},
182             'switch': {'type': 'Switch', 'kind': 'op', 'op': 'Switch'},
183             'switch_data_0': {'value': None, 'kind': 'data', 'executable': True},
184         }
185         edges = [
186             ('tensor', 'switch', {'in': 0}),
187             ('pred_id', 'switch', {'in': 1}),
188             ('switch', 'switch_data_0', {'out': 0}),
189         ]
190         graph = build_graph_with_edge_attrs(nodes, edges)
191         tested_class = Switch(graph=graph, attrs={})
192         node = Node(graph, 'switch')
193         tested_class.control_flow_infer(node, True, me_mock)
194         me_mock.assert_has_calls([call('switch_data_0', True)], any_order=True)
195
196     def test_switch_cf_true_no_exec(self):
197         me_mock = Mock()
198
199         nodes = {
200             'tensor': {'value': True, 'kind': 'data', 'executable': True},
201             'pred_id': {'value':  np.array(True), 'kind': 'data', 'executable': True},
202             'switch': {'type': 'Switch', 'kind': 'op', 'op': 'Switch'},
203             'switch_data_0': {'value': None, 'kind': 'data', 'executable': True}
204         }
205         edges = [
206             ('tensor', 'switch', {'in': 0}),
207             ('pred_id', 'switch', {'in': 1}),
208             ('switch', 'switch_data_0', {'out': 0}),
209         ]
210         graph = build_graph_with_edge_attrs(nodes, edges)
211         tested_class = Switch(graph=graph, attrs={})
212         node = Node(graph, 'switch')
213         tested_class.control_flow_infer(node, True, me_mock)
214         me_mock.assert_has_calls([call('switch_data_0', False)], any_order=True)
215
216     def test_switch_cf_false_no_exec(self):
217         me_mock = Mock()
218
219         nodes = {
220             'tensor': {'value': True, 'kind': 'data', 'executable': True},
221             'pred_id': {'value': np.array(False), 'kind': 'data', 'executable': True},
222             'switch': {'type': 'Switch', 'kind': 'op', 'op': 'Switch'},
223             'switch_data_1': {'value': None, 'kind': 'data', 'executable': True}
224         }
225         edges = [
226             ('tensor', 'switch', {'in': 0}),
227             ('pred_id', 'switch', {'in': 1}),
228             ('switch', 'switch_data_1', {'out': 1})
229         ]
230         graph = build_graph_with_edge_attrs(nodes, edges)
231         tested_class = Switch(graph=graph, attrs={})
232         node = Node(graph, 'switch')
233         tested_class.control_flow_infer(node, True, me_mock)
234         me_mock.assert_has_calls([call('switch_data_1', False)], any_order=True)