"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest
from unittest.mock import Mock, call

import numpy as np

from extensions.ops.switch import Switch
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph_with_edge_attrs, compare_graphs, build_graph_with_attrs
class TestSwitch(unittest.TestCase):
    """Unit tests for shape/value inference and control-flow inference of the ``Switch`` op.

    Every test uses the same topology: data node ``tensor`` feeds input port 0 and data node
    ``pred_id`` feeds input port 1 of the ``switch`` op, whose output ports 0 and 1 produce
    ``switch_data_0`` and ``switch_data_1``. Some control-flow tests connect only one output
    port to cover partially connected switches.
    """

    @staticmethod
    def _infer_nodes_and_edges(tensor_value, tensor_shape, pred_value):
        """Return ``(nodes, edges)`` in the list form accepted by ``build_graph_with_attrs``.

        :param tensor_value: value of the ``tensor`` data node (``None`` = value not known)
        :param tensor_shape: shape of the ``tensor`` data node
        :param pred_value: value of the ``pred_id`` data node (``None`` = predicate not known)
        :return: tuple ``(nodes, edges)`` with both switch outputs connected
        """
        nodes = [
            ('tensor', {'value': tensor_value, 'kind': 'data', 'executable': True, 'shape': tensor_shape}),
            ('pred_id', {'value': pred_value, 'kind': 'data', 'executable': True}),
            ('switch', {'type': 'Switch', 'kind': 'op', 'op': 'Switch'}),
            ('switch_data_0', {'value': None, 'kind': 'data', 'executable': True}),
            ('switch_data_1', {'value': None, 'kind': 'data', 'executable': True}),
        ]
        edges = [
            ('tensor', 'switch', {'in': 0}),
            ('pred_id', 'switch', {'in': 1}),
            ('switch', 'switch_data_0', {'out': 0}),
            ('switch', 'switch_data_1', {'out': 1}),
        ]
        return nodes, edges

    @staticmethod
    def _build_cf_graph(pred_value, out_ports):
        """Build a graph for the control-flow inference tests.

        :param pred_value: value of the ``pred_id`` data node (``None`` = predicate not known)
        :param out_ports: output ports of ``switch`` that get a connected ``switch_data_<port>`` node
        :return: graph produced by ``build_graph_with_edge_attrs``
        """
        nodes = {
            'tensor': {'value': True, 'kind': 'data', 'executable': True},
            'pred_id': {'value': pred_value, 'kind': 'data', 'executable': True},
            'switch': {'type': 'Switch', 'kind': 'op', 'op': 'Switch'},
        }
        edges = [
            ('tensor', 'switch', {'in': 0}),
            ('pred_id', 'switch', {'in': 1}),
        ]
        for port in out_ports:
            data_name = 'switch_data_{}'.format(port)
            nodes[data_name] = {'value': None, 'kind': 'data', 'executable': True}
            edges.append(('switch', data_name, {'out': port}))
        return build_graph_with_edge_attrs(nodes, edges)

    def _check_cf_infer(self, pred_value, out_ports, expected_calls):
        """Run ``Switch.control_flow_infer`` on an executable switch and verify the marks.

        :param pred_value: value of the ``pred_id`` data node (``None`` = predicate not known)
        :param out_ports: output ports of ``switch`` that are connected to a data node
        :param expected_calls: ``call(node_id, is_executable)`` entries the mark-executable
            callback must have received (order is not checked)
        """
        graph = self._build_cf_graph(pred_value, out_ports)
        tested_class = Switch(graph=graph, attrs={})
        node = Node(graph, 'switch')
        # A fresh mock per run records every (node_id, is_executable) mark made by the op
        me_mock = Mock()
        tested_class.control_flow_infer(node, True, me_mock)
        me_mock.assert_has_calls(expected_calls, any_order=True)

    def test_switch_infer_with_condition(self):
        nodes, edges = self._infer_nodes_and_edges(tensor_value=np.zeros((3, 3)),
                                                   tensor_shape=np.array([3, 3]),
                                                   pred_value=True)
        graph = build_graph_with_attrs(nodes_with_attrs=nodes, edges_with_attrs=edges)

        # Known predicate: both shapes and values should be propagated to the outputs
        graph_ref = build_graph_with_attrs(nodes_with_attrs=nodes,
                                           edges_with_attrs=edges,
                                           update_nodes_attributes=[('switch_data_0', {'shape': np.array([3, 3]),
                                                                                       'value': np.zeros((3, 3))}),
                                                                    ('switch_data_1', {'shape': np.array([3, 3]),
                                                                                       'value': np.zeros((3, 3))})])

        tested_class = Switch(graph=graph, attrs={})
        node = Node(graph, 'switch')
        tested_class.infer(node)

        (flag, resp) = compare_graphs(graph, graph_ref, 'switch_data_0', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_switch_infer_no_condition(self):
        nodes, edges = self._infer_nodes_and_edges(tensor_value=None,
                                                   tensor_shape=np.array([1, 2, 1]),
                                                   pred_value=None)
        graph = build_graph_with_attrs(nodes_with_attrs=nodes, edges_with_attrs=edges)

        # Unknown predicate: only shapes should be propagated to the outputs
        graph_ref = build_graph_with_attrs(nodes_with_attrs=nodes,
                                           edges_with_attrs=edges,
                                           update_nodes_attributes=[('switch_data_0', {'shape': np.array([1, 2, 1])}),
                                                                    ('switch_data_1', {'shape': np.array([1, 2, 1])})])

        tested_class = Switch(graph=graph, attrs={})
        node = Node(graph, 'switch')
        tested_class.infer(node)

        (flag, resp) = compare_graphs(graph, graph_ref, 'switch_data_0', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_switch_cf_infer_no_condition(self):
        # Unknown predicate: all ports should be marked executable
        self._check_cf_infer(None, (0, 1),
                             [call('switch_data_0', True), call('switch_data_1', True)])

    def test_switch_cf_true_both_ports(self):
        self._check_cf_infer(np.array(True), (0, 1),
                             [call('switch_data_0', False), call('switch_data_1', True)])

    def test_switch_cf_false_both_ports(self):
        self._check_cf_infer(np.array(False), (0, 1),
                             [call('switch_data_0', True), call('switch_data_1', False)])

    def test_switch_cf_true_one_exec_port(self):
        self._check_cf_infer(np.array(True), (1,),
                             [call('switch_data_1', True)])

    def test_switch_cf_false_one_exec_port(self):
        self._check_cf_infer(np.array(False), (0,),
                             [call('switch_data_0', True)])

    def test_switch_cf_true_no_exec(self):
        self._check_cf_infer(np.array(True), (0,),
                             [call('switch_data_0', False)])

    def test_switch_cf_false_no_exec(self):
        self._check_cf_infer(np.array(False), (1,),
                             [call('switch_data_1', False)])