2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.front.common.partial_infer.utils import int64_array
22 from mo.graph.graph import Node
23 from mo.ops.crop import Crop
24 from mo.utils.unittest.graph import build_graph
class TestCropPartialInfer(unittest.TestCase):
    """Shape-inference tests for the three ``Crop`` configurations accepted by ``Crop.infer``.

    * type 1: ``axis`` + ``crop_begin``/``crop_end``; the tests expect that many elements
      to be removed from the start/end of each listed axis.
    * type 2: ``axis`` + ``dim``; the tests expect the listed axes to take the ``dim`` sizes.
    * type 3: ``axis`` + ``offset`` plus a second data input; the tests expect the output
      to take its sizes from that second input.

    Every positive test checks the shape written to ``crop_output``. Every ``_neg`` test
    breaks one attribute or input and checks that, after inference runs, the output shape
    is left unset (``None``).
    """

    @staticmethod
    def _create_graph_type1():
        """Build ``crop_input -> Crop -> crop_output`` configured with crop_begin/crop_end."""
        nodes_attributes = {'crop_input': {'shape': None, 'value': None, 'kind': 'data'},
                            'crop_node': {'type': 'Crop', 'kind': 'op'},
                            'crop_output': {'shape': None, 'value': None, 'kind': 'data'}
                            }
        return build_graph(nodes_attributes,
                           [
                               ('crop_input', 'crop_node'), ('crop_node', 'crop_output')
                           ],
                           {
                               'crop_input': {'shape': int64_array([1, 3, 224, 224])},
                               'crop_node': {'axis': int64_array([2, 3]),
                                             'crop_begin': int64_array([10, 15]),
                                             'crop_end': int64_array([10, 15])
                                             },
                           })

    @staticmethod
    def _create_graph_type2():
        """Build ``crop_input -> Crop -> crop_output`` configured with explicit ``dim`` sizes."""
        nodes_attributes = {'crop_input': {'shape': None, 'value': None, 'kind': 'data'},
                            'crop_node': {'type': 'Crop', 'kind': 'op'},
                            'crop_output': {'shape': None, 'value': None, 'kind': 'data'}
                            }
        return build_graph(nodes_attributes,
                           [
                               ('crop_input', 'crop_node'), ('crop_node', 'crop_output')
                           ],
                           {
                               'crop_input': {'shape': int64_array([1, 3, 224, 224])},
                               'crop_node': {'axis': int64_array([2, 3]), 'dim': int64_array([100, 150])},
                           })

    @staticmethod
    def _create_graph_type3():
        """Build a two-input Crop graph where ``crop_input2`` supplies the reference shape."""
        nodes_attributes = {'crop_input': {'shape': None, 'value': None, 'kind': 'data'},
                            'crop_input2': {'shape': None, 'value': None, 'kind': 'data'},
                            'crop_node': {'type': 'Crop', 'kind': 'op'},
                            'crop_output': {'shape': None, 'value': None, 'kind': 'data'}
                            }
        return build_graph(nodes_attributes,
                           [
                               ('crop_input', 'crop_node'), ('crop_input2', 'crop_node'),
                               ('crop_node', 'crop_output')
                           ],
                           {
                               'crop_input': {'shape': int64_array([1, 3, 224, 224])},
                               'crop_input2': {'shape': int64_array([1, 3, 100, 150])},
                               'crop_node': {'axis': 2, 'offset': int64_array([10, 15])},
                           })

    def _assert_output_shape(self, graph, exp_shape):
        """Assert that the shape stored on ``crop_output`` in ``graph`` equals ``exp_shape``."""
        res_shape = graph.node['crop_output']['shape']
        self.assertTrue(np.array_equal(exp_shape, res_shape),
                        'shapes do not match expected: {} and given: {}'.format(exp_shape, res_shape))

    def test_crop_type1_infer(self):
        graph = self._create_graph_type1()

        crop_node = Node(graph, 'crop_node')
        # Inference must run here; without it crop_output.shape stays None and the check fails.
        Crop.infer(crop_node)

        # 224 - 10 - 10 = 204 along axis 2 and 224 - 15 - 15 = 194 along axis 3.
        self._assert_output_shape(graph, int64_array([1, 3, 204, 194]))

    def test_crop_type1_infer_neg1(self):
        graph = self._create_graph_type1()

        crop_node = Node(graph, 'crop_node')
        crop_node['axis'] = None

        # Inference must actually run; otherwise the None check below passes vacuously.
        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)

    def test_crop_type1_infer_neg2(self):
        graph = self._create_graph_type1()

        crop_node = Node(graph, 'crop_node')
        # Three begin offsets for only two axes: inconsistent configuration.
        crop_node['crop_begin'] = int64_array([1, 2, 3])

        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)

    def test_crop_type2_infer(self):
        graph = self._create_graph_type2()

        crop_node = Node(graph, 'crop_node')
        Crop.infer(crop_node)

        self._assert_output_shape(graph, int64_array([1, 3, 100, 150]))

    def test_crop_type2_infer_neg1(self):
        graph = self._create_graph_type2()

        crop_node = Node(graph, 'crop_node')
        # Three sizes for only two axes: inconsistent configuration.
        crop_node['dim'] = int64_array([1, 2, 3])

        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)

    def test_crop_type2_infer_neg2(self):
        graph = self._create_graph_type2()

        crop_node = Node(graph, 'crop_node')
        # Neither ``dim`` nor ``crop_begin`` is available.
        crop_node['dim'] = None
        crop_node['crop_begin'] = None

        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)

    def test_crop_type3_infer(self):
        graph = self._create_graph_type3()

        crop_node = Node(graph, 'crop_node')
        Crop.infer(crop_node)

        self._assert_output_shape(graph, int64_array([1, 3, 100, 150]))

    def test_crop_type3_infer_neg1(self):
        graph = self._create_graph_type3()

        crop_node = Node(graph, 'crop_node')
        # Reference input without a known shape.
        crop_input2 = Node(graph, 'crop_input2')
        crop_input2.shape = None

        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)

    def test_crop_type3_infer_neg2(self):
        graph = self._create_graph_type3()

        crop_node = Node(graph, 'crop_node')
        crop_node['axis'] = None

        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)

    def test_crop_type3_infer_neg3(self):
        graph = self._create_graph_type3()

        crop_node = Node(graph, 'crop_node')
        crop_node['offset'] = None

        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)

    def test_crop_type3_infer_neg4(self):
        graph = self._create_graph_type3()

        crop_node = Node(graph, 'crop_node')
        # Reference shape larger than the 224x224 input along the cropped axes.
        crop_input2 = Node(graph, 'crop_input2')
        crop_input2.shape = int64_array([1, 4, 423, 563])

        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)