2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import unittest

import numpy as np

from extensions.ops.interp import InterpOp
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
# Default node attributes shared by every Interp shape-inference test in this
# module. Tests pass per-node overrides as the third argument of build_graph
# (e.g. 'parse_2nd_input': 'shape' in the two-blob Caffe test); presumably
# build_graph merges those into these defaults -- confirm in build_graph.
nodes_attributes = {'node_1': {'type': 'Identity', 'kind': 'op'},
                    'node_2': {'type': 'Identity', 'value': None, 'kind': 'data'},
                    'interp': {'type': 'Interp', 'kind': 'op', 'factor': None, 'parse_2nd_input': 'value'},
                    # Output data node whose 'shape' the tests inspect after inference.
                    'node_3': {'type': 'Identity', 'shape': None, 'value': None, 'kind': 'data'},
                    'op_output': {'kind': 'op', 'op': 'OpOutput'},
                    }
33 class TestInterpOp(unittest.TestCase):
34 def test_caffe_interp_infer_shrink(self):
35 graph = build_graph(nodes_attributes,
36 [('node_1', 'interp'),
38 ('node_3', 'op_output')
40 {'node_3': {'shape': None},
41 'node_1': {'shape': np.array([1, 3, 1025, 2049])},
42 'interp': {'shrink_factor': 2,
49 graph.graph['layout'] = 'NCHW'
51 interp_node = Node(graph, 'interp')
52 InterpOp.interp_infer(interp_node)
53 exp_shape = np.array([1, 3, 513, 1025])
54 res_shape = graph.node['node_3']['shape']
55 for i in range(0, len(exp_shape)):
56 self.assertEqual(exp_shape[i], res_shape[i])
58 def test_caffe_interp_infer_wh(self):
59 graph = build_graph(nodes_attributes,
60 [('node_1', 'interp'),
62 ('node_3', 'op_output')
64 {'node_3': {'shape': None},
65 'node_1': {'shape': np.array([1, 1024, 1, 1])},
66 'interp': {'width': 65,
73 graph.graph['layout'] = 'NCHW'
75 interp_node = Node(graph, 'interp')
76 InterpOp.interp_infer(interp_node)
77 exp_shape = np.array([1, 1024, 33, 65])
78 res_shape = graph.node['node_3']['shape']
79 for i in range(0, len(exp_shape)):
80 self.assertEqual(exp_shape[i], res_shape[i])
82 def test_caffe_interp_infer_zoom(self):
83 graph = build_graph(nodes_attributes,
84 [('node_1', 'interp'),
86 ('node_3', 'op_output')
88 {'node_3': {'shape': None},
89 'node_1': {'shape': np.array([1, 256, 33, 65])},
90 'interp': {'zoom_factor': 2,
97 graph.graph['layout'] = 'NCHW'
99 interp_node = Node(graph, 'interp')
100 InterpOp.interp_infer(interp_node)
101 exp_shape = np.array([1, 256, 66, 130])
102 res_shape = graph.node['node_3']['shape']
103 for i in range(0, len(exp_shape)):
104 self.assertEqual(exp_shape[i], res_shape[i])
106 def test_caffe_interp_infer_zoom_shrink(self):
107 graph = build_graph(nodes_attributes,
108 [('node_1', 'interp'),
109 ('interp', 'node_3'),
110 ('node_3', 'op_output')
112 {'node_3': {'shape': None},
113 'node_1': {'shape': np.array([1, 256, 33, 65])},
114 'interp': {'zoom_factor': 2,
121 graph.graph['layout'] = 'NCHW'
123 interp_node = Node(graph, 'interp')
124 InterpOp.interp_infer(interp_node)
125 exp_shape = np.array([1, 256, 33, 65])
126 res_shape = graph.node['node_3']['shape']
127 for i in range(0, len(exp_shape)):
128 self.assertEqual(exp_shape[i], res_shape[i])
130 def test_caffe_interp_infer_zoom_shrink_error(self):
131 graph = build_graph(nodes_attributes,
132 [('node_1', 'interp'),
133 ('interp', 'node_3'),
134 ('node_3', 'op_output')
136 {'node_3': {'shape': None},
137 'node_1': {'shape': np.array([1, 256, 33, 65])},
138 'interp': {'zoom_factor': 0,
145 graph.graph['layout'] = 'NCHW'
147 interp_node = Node(graph, 'interp')
148 InterpOp.interp_infer(interp_node)
149 self.assertIsNone(graph.node['node_3']['shape'])
151 def test_caffe_interp_infer_zoom_default(self):
152 graph = build_graph(nodes_attributes,
153 [('node_1', 'interp'),
154 ('interp', 'node_3'),
155 ('node_3', 'op_output')
157 {'node_3': {'shape': None},
158 'node_1': {'shape': np.array([1, 256, 33, 65])},
159 'interp': {'zoom_factor': 1,
167 graph.graph['layout'] = 'NCHW'
169 interp_node = Node(graph, 'interp')
170 InterpOp.interp_infer(interp_node)
171 exp_shape = np.array([1, 256, 33, 65])
172 res_shape = graph.node['node_3']['shape']
173 for i in range(0, len(exp_shape)):
174 self.assertEqual(exp_shape[i], res_shape[i])
176 def test_caffe_interp_2_blobs(self):
177 graph = build_graph(nodes_attributes,
178 [('node_1', 'interp'),
179 ('node_2', 'interp'),
180 ('interp', 'node_3'),
181 ('node_3', 'op_output')
183 {'node_3': {'shape': None},
184 'node_1': {'shape': np.array([1, 256, 33, 66])},
185 'node_2': {'shape': np.array([1, 1, 3, 6])},
186 'interp': {'zoom_factor': 1,
190 'parse_2nd_input': 'shape',
193 graph.graph['layout'] = 'NCHW'
195 interp_node = Node(graph, 'interp')
196 InterpOp.interp_infer(interp_node)
197 exp_shape = np.array([1, 256, 3, 6])
198 res_shape = graph.node['node_3']['shape']
199 for i in range(0, len(exp_shape)):
200 self.assertEqual(exp_shape[i], res_shape[i])
def test_tf_interp_infer_two_inputs(self):
    """TF-style Interp with two inputs in NHWC layout.

    The second input (node_2) carries the constant value [2, 3]. For an input
    of shape [1, 20, 30, 100] the test expects the inferred output shape
    [1, 2, 3, 100], i.e. the spatial dims replaced by node_2's value.
    """
    graph = build_graph(nodes_attributes,
                        [('node_1', 'interp'),
                         ('node_2', 'interp'),
                         ('interp', 'node_3')],
                        {'node_1': {'shape': np.array([1, 20, 30, 100])},
                         'node_2': {'shape': np.array([2]), 'value': np.array([2, 3])}})
    graph.graph['layout'] = 'NHWC'
    interp_node = Node(graph, 'interp')
    InterpOp.interp_infer(interp_node)
    exp_shape = np.array([1, 2, 3, 100])
    res_shape = graph.node['node_3']['shape']
    # Compare whole arrays: unlike a per-index loop over exp_shape, this also
    # fails (with a readable AssertionError) when inference leaves the shape
    # as None or produces a shape of a different rank.
    np.testing.assert_array_equal(res_shape, exp_shape)
def test_tf_interp_infer_one_input_hw(self):
    """TF-style Interp with a single input in NHWC layout.

    The target spatial size comes from the node's own 'height'/'width'
    attributes (4 and 6), with zoom/shrink factors disabled (None). For an
    input of shape [1, 20, 30, 100] the expected output shape is [1, 4, 6, 100].
    """
    graph = build_graph(nodes_attributes,
                        [('node_1', 'interp'),
                         ('interp', 'node_3')],
                        {'node_1': {'shape': np.array([1, 20, 30, 100])},
                         'interp': {'height': 4, 'width': 6, 'pad_beg': 0, 'pad_end': 0, 'zoom_factor': None,
                                    'shrink_factor': None}})
    graph.graph['layout'] = 'NHWC'
    interp_node = Node(graph, 'interp')
    InterpOp.interp_infer(interp_node)
    exp_shape = np.array([1, 4, 6, 100])
    res_shape = graph.node['node_3']['shape']
    # Whole-array comparison also catches a missing (None) or wrong-rank shape,
    # which the previous per-index loop over exp_shape silently accepted.
    np.testing.assert_array_equal(res_shape, exp_shape)