Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / crop_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.front.common.partial_infer.utils import int64_array
22 from mo.graph.graph import Node
23 from mo.ops.crop import Crop
24 from mo.utils.unittest.graph import build_graph
25
26
class TestCropPartialInfer(unittest.TestCase):
    """Shape-inference tests for ``Crop.infer``.

    Each graph is ``crop_input -> crop_node -> crop_output`` with a 1x3x224x224
    main input. The three Crop attribute layouts exercised are:

    * type 1 -- ``axis`` + ``crop_begin`` / ``crop_end``: the expected output
      loses ``crop_begin + crop_end`` elements on each listed axis;
    * type 2 -- ``axis`` + ``dim``: each listed axis is expected to shrink to
      ``dim`` elements;
    * type 3 -- ``axis`` + ``offset`` plus a second data input ``crop_input2``;
      in these tests the expected output shape equals that input's shape.

    Negative tests break one attribute or input and expect ``Crop.infer`` to
    leave the output shape as ``None``.
    """

    @staticmethod
    def _build_crop_graph(crop_attrs, second_input_shape=None):
        """Build the crop test graph.

        :param crop_attrs: attributes assigned to ``crop_node``.
        :param second_input_shape: when not None, a data node ``crop_input2``
            with this shape is added and connected to ``crop_node`` after
            ``crop_input``.
        :return: the graph returned by ``build_graph``.
        """
        if second_input_shape is None:
            input_names = ['crop_input']
        else:
            input_names = ['crop_input', 'crop_input2']

        # A separate attribute dict is built for every node; none are shared.
        nodes_attributes = {name: {'shape': None, 'value': None, 'kind': 'data'} for name in input_names}
        nodes_attributes['crop_node'] = {'type': 'Crop', 'kind': 'op'}
        nodes_attributes['crop_output'] = {'shape': None, 'value': None, 'kind': 'data'}

        # Edge order: main input, optional second input, then the output edge.
        edges = [(name, 'crop_node') for name in input_names]
        edges.append(('crop_node', 'crop_output'))

        update_attributes = {'crop_input': {'shape': int64_array([1, 3, 224, 224])},
                             'crop_node': crop_attrs}
        if second_input_shape is not None:
            update_attributes['crop_input2'] = {'shape': second_input_shape}

        return build_graph(nodes_attributes, edges, update_attributes)

    @classmethod
    def _create_graph_type1(cls):
        """Type 1 graph: ``axis`` [2, 3], ``crop_begin`` and ``crop_end`` both [10, 15]."""
        return cls._build_crop_graph({'axis': int64_array([2, 3]),
                                      'crop_begin': int64_array([10, 15]),
                                      'crop_end': int64_array([10, 15])})

    @classmethod
    def _create_graph_type2(cls):
        """Type 2 graph: ``axis`` [2, 3] and ``dim`` [100, 150]."""
        return cls._build_crop_graph({'axis': int64_array([2, 3]),
                                      'dim': int64_array([100, 150])})

    @classmethod
    def _create_graph_type3(cls):
        """Type 3 graph: scalar ``axis`` 2, ``offset`` [10, 15], second input 1x3x100x150."""
        return cls._build_crop_graph({'axis': 2, 'offset': int64_array([10, 15])},
                                     second_input_shape=int64_array([1, 3, 100, 150]))

    def _check_inferred_shape(self, crop_node, exp_shape):
        """Run ``Crop.infer`` on *crop_node* and assert its output shape equals *exp_shape*.

        :param crop_node: ``Node`` wrapping ``crop_node`` of a test graph.
        :param exp_shape: expected shape of the output data node.
        """
        Crop.infer(crop_node)
        res_shape = crop_node.out_node().shape

        # Report "not inferred" separately from "inferred but wrong".
        self.assertIsNotNone(res_shape, 'shape was not inferred, expected: {}'.format(exp_shape))
        np.testing.assert_array_equal(res_shape, exp_shape,
                                      'shapes do not match expected: {} and given: {}'.format(exp_shape,
                                                                                              res_shape))

    def _check_shape_not_inferred(self, crop_node):
        """Run ``Crop.infer`` on *crop_node* and assert its output shape is left as ``None``."""
        Crop.infer(crop_node)
        self.assertIsNone(crop_node.out_node().shape)

    # ---- type 1: axis + crop_begin / crop_end ----

    def test_crop_type1_infer(self):
        crop_node = Node(self._create_graph_type1(), 'crop_node')
        # axis 2: 224 - 10 - 10 = 204; axis 3: 224 - 15 - 15 = 194
        self._check_inferred_shape(crop_node, int64_array([1, 3, 204, 194]))

    def test_crop_type1_infer_neg1(self):
        crop_node = Node(self._create_graph_type1(), 'crop_node')
        crop_node['axis'] = None

        self._check_shape_not_inferred(crop_node)

    def test_crop_type1_infer_neg2(self):
        crop_node = Node(self._create_graph_type1(), 'crop_node')
        # crop_begin has three entries while axis lists only two
        crop_node['crop_begin'] = int64_array([1, 2, 3])

        self._check_shape_not_inferred(crop_node)

    # ---- type 2: axis + dim ----

    def test_crop_type2_infer(self):
        crop_node = Node(self._create_graph_type2(), 'crop_node')
        self._check_inferred_shape(crop_node, int64_array([1, 3, 100, 150]))

    def test_crop_type2_infer_neg1(self):
        crop_node = Node(self._create_graph_type2(), 'crop_node')
        # dim has three entries while axis lists only two
        crop_node['dim'] = int64_array([1, 2, 3])

        self._check_shape_not_inferred(crop_node)

    def test_crop_type2_infer_neg2(self):
        crop_node = Node(self._create_graph_type2(), 'crop_node')
        crop_node['dim'] = None
        crop_node['crop_begin'] = None

        self._check_shape_not_inferred(crop_node)

    # ---- type 3: axis + offset + second (reference) input ----

    def test_crop_type3_infer(self):
        crop_node = Node(self._create_graph_type3(), 'crop_node')
        self._check_inferred_shape(crop_node, int64_array([1, 3, 100, 150]))

    def test_crop_type3_infer_neg1(self):
        graph = self._create_graph_type3()
        crop_node = Node(graph, 'crop_node')
        crop_input2 = Node(graph, 'crop_input2')
        crop_input2.shape = None

        self._check_shape_not_inferred(crop_node)

    def test_crop_type3_infer_neg2(self):
        crop_node = Node(self._create_graph_type3(), 'crop_node')
        crop_node['axis'] = None

        self._check_shape_not_inferred(crop_node)

    def test_crop_type3_infer_neg3(self):
        crop_node = Node(self._create_graph_type3(), 'crop_node')
        crop_node['offset'] = None

        self._check_shape_not_inferred(crop_node)

    def test_crop_type3_infer_neg4(self):
        graph = self._create_graph_type3()
        crop_node = Node(graph, 'crop_node')
        crop_input2 = Node(graph, 'crop_input2')
        # second input is larger than the main input on the cropped axes
        crop_input2.shape = int64_array([1, 4, 423, 563])

        self._check_shape_not_inferred(crop_node)