"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest

import numpy as np

from mo.front.common.partial_infer.eltwise import eltwise_infer
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
# Node attributes shared by every test in this module:
#   node_1, node_2 -- data inputs holding the values 2 and 3,
#   eltw_1         -- the Eltwise op under test,
#   node_3         -- the Eltwise output data node (no value until inferred),
#   op_output      -- an 'OpOutput' op consuming node_3.
nodes_attributes = {'node_1': {'value': 2, 'kind': 'data'},
                    'node_2': {'value': 3, 'kind': 'data'},
                    'eltw_1': {'type': 'Eltwise', 'kind': 'op'},
                    'node_3': {'value': None, 'kind': 'data'},
                    'op_output': {'kind': 'op', 'op': 'OpOutput'},
                    }
class TestEltwiseInfer(unittest.TestCase):
    """Shape and value inference tests for ``eltwise_infer``.

    Every test uses the same topology,
    node_1, node_2 -> eltw_1 -> node_3 -> op_output,
    where node_1 and node_2 carry the values 2 and 3 from ``nodes_attributes``.
    """

    @staticmethod
    def _build_eltwise_graph(node_1_attrs, node_2_attrs):
        """Build the two-input Eltwise test graph.

        Kept in one place so every test gets both input edges and the output
        edge; duplicating this edge list per test is how edges went missing.

        :param node_1_attrs: attributes overriding those of ``node_1`` (e.g. its 'shape').
        :param node_2_attrs: attributes overriding those of ``node_2`` (e.g. its 'shape').
        :return: the built graph, with node_3's 'shape' set to None and layout 'NCHW'.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'eltw_1'),
                             ('node_2', 'eltw_1'),
                             ('eltw_1', 'node_3'),
                             ('node_3', 'op_output'),
                             ],
                            {'node_3': {'shape': None},
                             'node_1': node_1_attrs,
                             'node_2': node_2_attrs,
                             })
        graph.graph['layout'] = 'NCHW'
        return graph

    def test_eltwise_infer_max(self):
        """Equal input shapes with an element-wise maximum: shape is kept, value is max(2, 3)."""
        graph = self._build_eltwise_graph({'shape': np.array([1, 3, 256, 256])},
                                          {'shape': np.array([1, 3, 256, 256])})
        eltwise_node = Node(graph, 'eltw_1')

        eltwise_infer(eltwise_node, lambda a, b: np.maximum(a, b))
        exp_shape = np.array([1, 3, 256, 256])
        exp_value = 3

        # Unlike an element-wise loop, this also fails on a length mismatch or a None shape.
        np.testing.assert_array_equal(graph.node['node_3']['shape'], exp_shape)
        self.assertEqual(exp_value, eltwise_node.out_node().value)

    def test_eltwise_infer_sum(self):
        """Equal input shapes with addition: shape is kept, value is 2 + 3."""
        graph = self._build_eltwise_graph({'shape': np.array([1, 3, 256, 256])},
                                          {'shape': np.array([1, 3, 256, 256])})
        eltwise_node = Node(graph, 'eltw_1')

        eltwise_infer(eltwise_node, lambda a, b: a + b)
        exp_shape = np.array([1, 3, 256, 256])
        exp_value = 5

        np.testing.assert_array_equal(graph.node['node_3']['shape'], exp_shape)
        self.assertEqual(exp_value, eltwise_node.out_node().value)

    def test_eltwise_infer_mul(self):
        """Equal input shapes with multiplication: shape is kept, value is 2 * 3."""
        graph = self._build_eltwise_graph({'shape': np.array([1, 3, 256, 256])},
                                          {'shape': np.array([1, 3, 256, 256])})
        eltwise_node = Node(graph, 'eltw_1')

        eltwise_infer(eltwise_node, lambda a, b: a * b)
        exp_shape = np.array([1, 3, 256, 256])
        exp_value = 6

        np.testing.assert_array_equal(graph.node['node_3']['shape'], exp_shape)
        self.assertEqual(exp_value, eltwise_node.out_node().value)

    def test_eltwise_infer_none_val(self):
        """An input without a value: the shape is still inferred, but the output value stays None."""
        # Override node_1's value (2 in nodes_attributes) with None.
        graph = self._build_eltwise_graph({'shape': np.array([1, 3, 256, 256]), 'value': None},
                                          {'shape': np.array([1, 3, 256, 256])})
        eltwise_node = Node(graph, 'eltw_1')

        eltwise_infer(eltwise_node, lambda a, b: a * b)
        exp_shape = np.array([1, 3, 256, 256])

        np.testing.assert_array_equal(graph.node['node_3']['shape'], exp_shape)
        self.assertIsNone(eltwise_node.out_node().value)

    def test_eltwise_infer_none_min_max(self):
        """Incompatible spatial dims (257 vs 256) are expected to infer as -1; no op is passed."""
        graph = self._build_eltwise_graph({'shape': np.array([1, 3, 257, 256])},
                                          {'shape': np.array([1, 3, 256, 257])})
        eltwise_node = Node(graph, 'eltw_1')

        eltwise_infer(eltwise_node)
        exp_shape = np.array([1, 3, -1, -1])

        np.testing.assert_array_equal(graph.node['node_3']['shape'], exp_shape)